* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file block_mmad_fixpipe_dequant.h
* \brief
*/
#ifndef BLOCK_MMAD_FIXPIPE_DEQUANT_H
#define BLOCK_MMAD_FIXPIPE_DEQUANT_H
#include "../policy/dispatch_policy.h"
#include "../utils/common_utils.h"
#include "../utils/layout_utils.h"
#include "../utils/tuple_utils.h"
namespace Cgmct {
namespace Gemm {
namespace Block {
namespace {
constexpr static uint64_t HALF_L0C_SIZE = AscendC::TOTAL_L0C_SIZE / DOUBLE_BUFFER_COUNT / sizeof(float);
constexpr static uint64_t BLOCK_REDUCE_CUBE = 32UL;
constexpr static uint64_t IDX_M_TILE_IDX = 0UL;
constexpr static uint64_t IDX_N_TILE_IDX = 1UL;
constexpr static uint64_t IDX_M_IDX = 0UL;
constexpr static uint64_t IDX_N_IDX = 1UL;
constexpr static uint64_t IDX_K_IDX = 2UL;
constexpr static uint64_t EVEN_FACTOR = 2UL;
constexpr static uint64_t B8_MIN_STEP = 2UL;
constexpr static uint16_t INPUT_BUFFER_FLAG_0 = 0;
constexpr static uint16_t INPUT_BUFFER_FLAG_1 = 1;
constexpr static uint16_t INPUT_BUFFER_FLAG_2 = 2;
constexpr static uint16_t INPUT_BUFFER_FLAG_3 = 3;
struct TileL1L0Param {
uint64_t curM = 0;
uint64_t curN = 0;
uint64_t curAlignM = 0;
uint64_t curAlignN = 0;
uint64_t curGmAKL1 = 0;
uint64_t curGmBKL1 = 0;
uint64_t curPadAKL1 = 0;
uint64_t curPadBKL1 = 0;
uint64_t curKL0 = 0;
};
}
template <class DispatchPolicy_, class L1TileShape_, class L0TileShape_, class AType_, class LayoutA_, class BType_,
class LayoutB_, class CType_, class LayoutC_, class BiasType_, class LayoutBias_, class TileCopy_,
class Enable = void>
class BlockMmadFixpipeDequant {
static_assert(AscendC::Std::always_false_v<DispatchPolicy_>, "Should not be here!");
};
template <class DispatchPolicy_, class L1TileShape_, class L0TileShape_, class AType_, class LayoutA_, class BType_,
class LayoutB_, class CType_, class LayoutC_, class BiasType_, class LayoutBias_, class TileCopy_>
class BlockMmadFixpipeDequant<
DispatchPolicy_, L1TileShape_, L0TileShape_, AType_, LayoutA_, BType_, LayoutB_, CType_, LayoutC_, BiasType_,
LayoutBias_, TileCopy_,
AscendC::Std::enable_if_t<AscendC::Std::is_base_of_v<MatmulWithScale<>, DispatchPolicy_>>> {
public:
using AType = AType_;
using BType = BType_;
using CType = CType_;
using LayoutA = LayoutA_;
using LayoutB = LayoutB_;
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using BiasType = BiasType_;
using DispatchPolicy = DispatchPolicy_;
using TupleShape = AscendC::Shape<int64_t, int64_t, int64_t>;
using BlockShape = AscendC::Shape<int64_t, int64_t, int64_t>;
static constexpr bool transA = TagToTrans<LayoutA>::value;
static constexpr bool transB = TagToTrans<LayoutB>::value;
static constexpr CubeFormat formatB = TagToFormat<LayoutB>::format;
constexpr static uint64_t HALF_L0_SIZE = L0A_SIZE / DOUBLE_BUFFER_COUNT;
constexpr static int32_t C0_SIZE = AscendC::AuxGetC0Size<AType>();
struct Params {
GM_ADDR aGmAddr{nullptr};
GM_ADDR bGmAddr{nullptr};
GM_ADDR x2ScaleGmAddr{nullptr};
GM_ADDR x1ScaleGmAddr{nullptr};
GM_ADDR cGmAddr{nullptr};
GM_ADDR biasGmAddr{nullptr};
GM_ADDR groupListGmAddr{nullptr};
};
struct L1Params {
uint64_t kAL1;
uint64_t kBL1;
uint64_t l1BufNum;
};
__aicore__ inline BlockMmadFixpipeDequant()
{
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_0);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_1);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_2);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_3);
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(INPUT_BUFFER_FLAG_0);
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(INPUT_BUFFER_FLAG_1);
AscendC::SetMMLayoutTransform(true);
}
__aicore__ inline ~BlockMmadFixpipeDequant()
{
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_0);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_1);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_3);
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(INPUT_BUFFER_FLAG_0);
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(INPUT_BUFFER_FLAG_1);
AscendC::SetMMLayoutTransform(false);
}
public:
__aicore__ inline void Init(const TupleShape &problemShape, const BlockShape &l0TileShape, const L1Params &l1Params,
bool isBias, bool dbL0C)
{
m_ = Get<IDX_M_IDX>(problemShape);
n_ = Get<IDX_N_IDX>(problemShape);
k_ = Get<IDX_K_IDX>(problemShape);
baseM_ = Get<IDX_M_IDX>(l0TileShape);
baseN_ = Get<IDX_N_IDX>(l0TileShape);
baseK_ = Get<IDX_K_IDX>(l0TileShape);
kAL1_ = l1Params.kAL1;
kBL1_ = l1Params.kBL1;
orderAL1BL1_ = l1Params.kAL1 >= l1Params.kBL1;
minKL1_ = Min(l1Params.kAL1, l1Params.kBL1);
isBias_ = isBias;
l1BufNum_ = l1Params.l1BufNum;
enableL0cPingPong_ = dbL0C;
aL1OneBuffer_ = baseM_ * kAL1_;
bL1OneBuffer_ = baseN_ * kBL1_;
for (int32_t bufferId = 0; bufferId < l1BufNum_; bufferId++) {
uint64_t l1Offset = (AscendC::TOTAL_L1_SIZE >> 1) * (bufferId & 1);
l1BufferAOffset_[bufferId] = l1Offset + aL1OneBuffer_ * (bufferId >> 1);
l1BufferBOffset_[bufferId] = l1Offset + aL1OneBuffer_ * (l1BufNum_ >> 1) + bL1OneBuffer_ * (bufferId >> 1);
}
}
__aicore__ inline void UpdateParamsForNextProblem(const TupleShape &problemShape)
{
m_ = Get<IDX_M_IDX>(problemShape);
n_ = Get<IDX_N_IDX>(problemShape);
k_ = Get<IDX_K_IDX>(problemShape);
}
__aicore__ inline void CopyInA1(const AscendC::GlobalTensor<AType> &aGlobal,
const AscendC::LocalTensor<AType> &al1Local, const TileL1L0Param &tileL1L0Param)
{
AscendC::Nd2NzParams nd2nzParams;
nd2nzParams.ndNum = 1;
uint64_t nDim = transA ? tileL1L0Param.curGmAKL1 : tileL1L0Param.curM;
uint64_t dDim = transA ? tileL1L0Param.curM : tileL1L0Param.curGmAKL1;
nd2nzParams.nValue = nDim;
nd2nzParams.dValue = dDim;
nd2nzParams.srcNdMatrixStride = 1;
nd2nzParams.srcDValue = transA ? m_ : k_;
nd2nzParams.dstNzC0Stride = transA ? tileL1L0Param.curPadAKL1 : tileL1L0Param.curAlignM;
nd2nzParams.dstNzNStride = 1;
nd2nzParams.dstNzMatrixStride = 1;
AscendC::DataCopy(al1Local, aGlobal, nd2nzParams);
}
__aicore__ inline void CopyInB1(const AscendC::GlobalTensor<BType> &bGlobal,
const AscendC::LocalTensor<BType> &bl1Local, const TileL1L0Param &tileL1L0Param)
{
AscendC::Nd2NzParams nd2nzParams;
nd2nzParams.ndNum = 1;
uint64_t nDim = transB ? tileL1L0Param.curN : tileL1L0Param.curGmBKL1;
uint64_t dDim = transB ? tileL1L0Param.curGmBKL1 : tileL1L0Param.curN;
nd2nzParams.nValue = nDim;
nd2nzParams.dValue = dDim;
nd2nzParams.srcNdMatrixStride = 1;
nd2nzParams.srcDValue = transB ? k_ : n_;
nd2nzParams.dstNzC0Stride = transB ? tileL1L0Param.curAlignN : tileL1L0Param.curPadBKL1;
nd2nzParams.dstNzNStride = 1;
nd2nzParams.dstNzMatrixStride = 1;
AscendC::DataCopy(bl1Local, bGlobal, nd2nzParams);
}
__aicore__ inline void CopyInL0A(const AscendC::LocalTensor<AType> &l0aLocal,
const AscendC::LocalTensor<AType> &al1Local, const TileL1L0Param &tileL1L0Param,
uint64_t kL0L1Off)
{
AscendC::LoadData2DParamsV2 loadDataParams;
uint64_t m1 = Cgmct::Gemm::CeilDiv(tileL1L0Param.curM, AscendC::BLOCK_CUBE);
if constexpr (transA) {
loadDataParams.mStartPosition = Cgmct::Gemm::CeilDiv(kL0L1Off, AscendC::BLOCK_CUBE);
loadDataParams.kStartPosition = 0;
if ((m1 & 1) == 0) {
loadDataParams.mStep = Cgmct::Gemm::CeilAlign(
Cgmct::Gemm::CeilDiv(tileL1L0Param.curKL0, AscendC::BLOCK_CUBE), EVEN_FACTOR);
} else {
loadDataParams.mStep = B8_MIN_STEP;
}
loadDataParams.kStep = Cgmct::Gemm::CeilDiv(tileL1L0Param.curM, C0_SIZE);
loadDataParams.srcStride = Cgmct::Gemm::CeilDiv(tileL1L0Param.curPadAKL1, AscendC::BLOCK_CUBE);
loadDataParams.dstStride = m1;
loadDataParams.ifTranspose = true;
}
AscendC::LoadData(l0aLocal, al1Local, loadDataParams);
if constexpr (transA) {
if ((m1 & 1) != 0) {
AscendC::PipeBarrier<PIPE_MTE1>();
uint64_t loadTimes =
Cgmct::Gemm::CeilDiv(Cgmct::Gemm::CeilDiv(tileL1L0Param.curKL0, AscendC::BLOCK_CUBE), B8_MIN_STEP);
for (uint64_t i = 1; i < loadTimes; i++) {
loadDataParams.mStartPosition =
B8_MIN_STEP * i + Cgmct::Gemm::CeilDiv(kL0L1Off, AscendC::BLOCK_CUBE);
AscendC::LoadData(l0aLocal[i * m1 * AscendC::BLOCK_CUBE * C0_SIZE], al1Local, loadDataParams);
AscendC::PipeBarrier<PIPE_MTE1>();
}
}
}
}
__aicore__ inline void CopyInL0B(const AscendC::LocalTensor<BType> &l0bLocal,
const AscendC::LocalTensor<BType> &bl1Local, const TileL1L0Param &tileL1L0Param,
uint64_t kL0L1Off)
{
AscendC::LoadData2DParamsV2 loadDataParams;
uint64_t n1 = Cgmct::Gemm::CeilDiv(tileL1L0Param.curN, AscendC::BLOCK_CUBE);
if constexpr (!transB) {
loadDataParams.mStartPosition = Cgmct::Gemm::CeilDiv(kL0L1Off, AscendC::BLOCK_CUBE);
loadDataParams.kStartPosition = 0;
if ((n1 & 1) == 0) {
loadDataParams.mStep = Cgmct::Gemm::CeilAlign(
Cgmct::Gemm::CeilDiv(tileL1L0Param.curKL0, AscendC::BLOCK_CUBE), EVEN_FACTOR);
} else {
loadDataParams.mStep = B8_MIN_STEP;
}
loadDataParams.kStep = Cgmct::Gemm::CeilDiv(tileL1L0Param.curN, C0_SIZE);
loadDataParams.srcStride = Cgmct::Gemm::CeilDiv(tileL1L0Param.curPadBKL1, AscendC::BLOCK_CUBE);
loadDataParams.dstStride = n1;
loadDataParams.ifTranspose = true;
}
AscendC::LoadData(l0bLocal, bl1Local, loadDataParams);
if constexpr (!transB) {
if ((n1 & 1) != 0) {
AscendC::PipeBarrier<PIPE_MTE1>();
uint64_t loadTimes =
Cgmct::Gemm::CeilDiv(Cgmct::Gemm::CeilDiv(tileL1L0Param.curKL0, AscendC::BLOCK_CUBE), B8_MIN_STEP);
for (uint64_t i = 1; i < loadTimes; i++) {
loadDataParams.mStartPosition =
B8_MIN_STEP * i + Cgmct::Gemm::CeilDiv(kL0L1Off, AscendC::BLOCK_CUBE);
AscendC::LoadData(l0bLocal[i * n1 * AscendC::BLOCK_CUBE * C0_SIZE], bl1Local, loadDataParams);
AscendC::PipeBarrier<PIPE_MTE1>();
}
}
}
}
__aicore__ inline void CopyFixpipeScalar(const AscendC::GlobalTensor<CType> &cGlobal,
AscendC::LocalTensor<float> &c1Local, MmadParams &mmadParams,
uint64_t scalarX2Scale)
{
AscendC::FixpipeParamsC310 intriParams;
intriParams.nSize = mmadParams.n;
intriParams.mSize = mmadParams.m;
intriParams.dstStride = n_;
intriParams.srcStride = Cgmct::Gemm::Align(mmadParams.m, AscendC::BLOCK_CUBE);
intriParams.quantPre = QuantMode_t::QF322F32_PRE;
intriParams.deqScalar = scalarX2Scale;
intriParams.unitFlag = FINAL_ACCUMULATION;
intriParams.params = {1, 1, 1};
AscendC::Fixpipe(cGlobal, c1Local, intriParams);
}
__aicore__ inline void UpdateKAL1(TileL1L0Param &tileL1L0Param, uint64_t offsetKAL1)
{
tileL1L0Param.curGmAKL1 = (offsetKAL1 + kAL1_ > k_) ? (k_ - offsetKAL1) : kAL1_;
if constexpr (transA) {
tileL1L0Param.curPadAKL1 = Cgmct::Gemm::CeilAlign(tileL1L0Param.curGmAKL1, AscendC::BLOCK_CUBE);
} else {
tileL1L0Param.curPadAKL1 = Cgmct::Gemm::CeilAlign(tileL1L0Param.curGmAKL1, C0_SIZE);
}
}
__aicore__ inline void UpdateKBL1(TileL1L0Param &tileL1L0Param, uint64_t offsetKBL1)
{
tileL1L0Param.curGmBKL1 = (offsetKBL1 + kBL1_ > k_) ? (k_ - offsetKBL1) : kBL1_;
if constexpr (transB) {
tileL1L0Param.curPadBKL1 = Cgmct::Gemm::CeilAlign(tileL1L0Param.curGmBKL1, C0_SIZE);
} else {
tileL1L0Param.curPadBKL1 = Cgmct::Gemm::CeilAlign(tileL1L0Param.curGmBKL1, AscendC::BLOCK_CUBE);
}
}
__aicore__ inline void UpdateKL0(TileL1L0Param &tileL1L0Param, uint64_t kL0Offset, uint64_t minGmKL1)
{
if (kL0Offset + baseK_ > minGmKL1) {
tileL1L0Param.curKL0 = minGmKL1 - kL0Offset;
} else {
tileL1L0Param.curKL0 = baseK_;
}
}
__aicore__ inline void GetAlignMN(TileL1L0Param &tileL1L0Param)
{
if constexpr (transA) {
tileL1L0Param.curAlignM = Cgmct::Gemm::CeilAlign(tileL1L0Param.curM, C0_SIZE);
} else {
tileL1L0Param.curAlignM = Cgmct::Gemm::CeilAlign(tileL1L0Param.curM, AscendC::BLOCK_CUBE);
}
if constexpr (!transB) {
tileL1L0Param.curAlignN = Cgmct::Gemm::CeilAlign(tileL1L0Param.curN, C0_SIZE);
} else {
tileL1L0Param.curAlignN = Cgmct::Gemm::CeilAlign(tileL1L0Param.curN, AscendC::BLOCK_CUBE);
}
}
__aicore__ inline void CopyAInL1(const AscendC::GlobalTensor<AType> &aGlobal, const TileL1L0Param &tileL1L0Param,
uint64_t l1BufId, uint64_t kL1Offset)
{
uint64_t offsetA = transA ? kL1Offset * m_ : kL1Offset;
CopyInA1(aGlobal[offsetA], aL1Local_[l1BufferAOffset_[l1BufId]], tileL1L0Param);
}
__aicore__ inline void CopyBInL1(const AscendC::GlobalTensor<BType> &bGlobal, const TileL1L0Param &tileL1L0Param,
uint64_t l1BufId, uint64_t kL1Offset)
{
uint64_t offsetB = transB ? kL1Offset : kL1Offset * n_;
CopyInB1(bGlobal[offsetB], bL1Local_[l1BufferBOffset_[l1BufId]], tileL1L0Param);
}
__aicore__ inline void Iterate(TileL1L0Param &tileL1L0Param, AscendC::MmadParams &mmadParams, uint64_t kL1Offset,
uint64_t al1BufId, uint64_t bl1BufId, uint64_t l0cOffset, uint64_t kaL1Offset,
uint64_t kbL1Offset)
{
uint64_t minGmKL1 = Cgmct::Gemm::Min(tileL1L0Param.curGmBKL1, tileL1L0Param.curGmAKL1);
for (uint64_t kL0Offset = kL1Offset; kL0Offset < Min(kL1Offset + minKL1_, k_); kL0Offset += baseK_) {
UpdateKL0(tileL1L0Param, kL0Offset - kL1Offset, minGmKL1);
uint64_t kL0L1Off = kL0Offset - kL1Offset;
uint64_t l0Offset = HALF_L0_SIZE * (l0PingPong_ & 0x1);
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0PingPong_ & 0x1);
CopyInL0A(l0aLocal_[l0Offset], aL1Local_[l1BufferAOffset_[al1BufId]], tileL1L0Param, kL0L1Off + kaL1Offset);
CopyInL0B(l0bLocal_[l0Offset], bL1Local_[l1BufferBOffset_[bl1BufId]], tileL1L0Param, kL0L1Off + kbL1Offset);
AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(l0PingPong_ & 0x1);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(l0PingPong_ & 0x1);
mmadParams.k = tileL1L0Param.curKL0;
mmadParams.unitFlag = (kL0Offset + baseK_ >= k_) ? FINAL_ACCUMULATION : NON_FINAL_ACCUMULATION;
mmadParams.cmatrixInitVal = (kL0Offset == 0 && !isBias_);
Mmad(mmadParams, l0cOffset, l0Offset, baseN_ * biasBufId_, NeedBias(kL0Offset));
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0PingPong_ & 0x1);
l0PingPong_++;
}
}
__aicore__ inline void IterBL1AL1(TileL1L0Param &tileL1L0Param, AscendC::MmadParams &mmadParams, uint64_t l0cOffset,
const AscendC::GlobalTensor<AType> &aGlobal,
const AscendC::GlobalTensor<BType> &bGlobal,
const AscendC::GlobalTensor<BiasType> &biasGlobal)
{
for (uint64_t kOuter = 0; kOuter < k_; kOuter += kBL1_) {
uint64_t bL1BufId = bL1LoopCnt_ & (l1BufNum_ - 1);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(bL1BufId);
UpdateKBL1(tileL1L0Param, kOuter);
CopyBInL1(bGlobal, tileL1L0Param, bL1BufId, kOuter);
for (uint64_t kInner = kOuter; kInner < Min(kOuter + kBL1_, k_); kInner += kAL1_) {
uint64_t aL1BufId = aL1LoopCnt_ & (l1BufNum_ - 1);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_2 + aL1BufId);
UpdateKAL1(tileL1L0Param, kInner);
CopyAInL1(aGlobal, tileL1L0Param, aL1BufId, kInner);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(aL1BufId);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(aL1BufId);
Iterate(tileL1L0Param, mmadParams, kInner, aL1BufId, bL1BufId, l0cOffset, 0, kInner - kOuter);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_2 + aL1BufId);
aL1LoopCnt_++;
}
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(bL1BufId);
bL1LoopCnt_++;
}
}
__aicore__ inline void IterAL1BL1(TileL1L0Param &tileL1L0Param, AscendC::MmadParams &mmadParams, uint64_t l0cOffset,
const AscendC::GlobalTensor<AType> &aGlobal,
const AscendC::GlobalTensor<BType> &bGlobal,
const AscendC::GlobalTensor<BiasType> &biasGlobal)
{
for (uint64_t kOuter = 0; kOuter < k_; kOuter += kAL1_) {
uint64_t aL1BufId = aL1LoopCnt_ & (l1BufNum_ - 1);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(aL1BufId);
UpdateKAL1(tileL1L0Param, kOuter);
CopyAInL1(aGlobal, tileL1L0Param, aL1BufId, kOuter);
for (uint64_t kInner = kOuter; kInner < Min(kOuter + kAL1_, k_); kInner += kBL1_) {
uint64_t bL1BufId = bL1LoopCnt_ & (l1BufNum_ - 1);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_2 + bL1BufId);
UpdateKBL1(tileL1L0Param, kInner);
CopyBInL1(bGlobal, tileL1L0Param, bL1BufId, kInner);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(bL1BufId);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(bL1BufId);
Iterate(tileL1L0Param, mmadParams, kInner, aL1BufId, bL1BufId, l0cOffset, kInner - kOuter, 0);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(INPUT_BUFFER_FLAG_2 + bL1BufId);
bL1LoopCnt_++;
}
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(aL1BufId);
aL1LoopCnt_++;
}
}
__aicore__ inline void run(const AscendC::GlobalTensor<AType> &aGlobal, const AscendC::GlobalTensor<BType> &bGlobal,
uint64_t scalarScale, const AscendC::GlobalTensor<BiasType> &biasGlobal,
const AscendC::GlobalTensor<CType> &cGlobal, const BlockShape &singleShape)
{
TileL1L0Param tileL1L0Param;
tileL1L0Param.curM = Get<IDX_M_TILE_IDX>(singleShape);
tileL1L0Param.curN = Get<IDX_N_TILE_IDX>(singleShape);
GetAlignMN(tileL1L0Param);
AscendC::MmadParams mmadParams;
mmadParams.m = tileL1L0Param.curM;
mmadParams.n = tileL1L0Param.curN;
mmadParams.disableGemv = true;
uint64_t l0cOffset = (l0cPingPong_ & 1) * HALF_L0C_SIZE;
if (orderAL1BL1_) {
IterAL1BL1(tileL1L0Param, mmadParams, l0cOffset, aGlobal, bGlobal, biasGlobal);
} else {
IterBL1AL1(tileL1L0Param, mmadParams, l0cOffset, aGlobal, bGlobal, biasGlobal);
}
AscendC::LocalTensor<float> c1Local = c1Local_[l0cOffset];
CopyFixpipeScalar(cGlobal, c1Local, mmadParams, scalarScale);
if (enableL0cPingPong_) {
l0cPingPong_++;
}
}
__aicore__ inline void operator()(const AscendC::GlobalTensor<AType> &aGlobal,
const AscendC::GlobalTensor<BType> &bGlobal, uint64_t scalarScale,
const AscendC::GlobalTensor<CType> &cGlobal, const BlockShape &singleShape)
{
AscendC::GlobalTensor<BiasType> biasGlobal;
run(aGlobal, bGlobal, scalarScale, biasGlobal, cGlobal, singleShape);
}
__aicore__ inline void operator()(const AscendC::GlobalTensor<AType> &aGlobal,
const AscendC::GlobalTensor<BType> &bGlobal, uint64_t scalarScale,
const AscendC::GlobalTensor<BiasType> &biasGlobal,
const AscendC::GlobalTensor<CType> &cGlobal, const BlockShape &singleShape)
{
run(aGlobal, bGlobal, scalarScale, biasGlobal, cGlobal, singleShape);
}
private:
__aicore__ inline bool NeedBias(uint64_t kL0Offset) { return isBias_ && kL0Offset == 0; }
__aicore__ inline void Mmad(AscendC::MmadParams &mmadParams, uint64_t l0cOffset, uint64_t l0abOffset,
uint64_t biasOffset, bool needBias)
{
mmadParams.cmatrixSource = needBias;
if (needBias) {
AscendC::Mmad(c1Local_[l0cOffset], l0aLocal_[l0abOffset], l0bLocal_[l0abOffset], biasBt_[biasOffset],
mmadParams);
} else {
AscendC::Mmad(c1Local_[l0cOffset], l0aLocal_[l0abOffset], l0bLocal_[l0abOffset], mmadParams);
}
}
private:
uint16_t biasBufId_ = 0;
uint64_t biasL1OneBuffer_ = 0UL;
uint64_t aL1OneBuffer_ = 0UL;
uint64_t bL1OneBuffer_ = 0UL;
uint64_t l1BufferAOffset_[2] = {0UL};
uint64_t l1BufferBOffset_[2] = {0UL};
uint64_t l1BufferBiasOffset_[2] = {0UL};
uint64_t m_;
uint64_t n_;
uint64_t k_;
uint64_t l1BufNum_{1};
uint64_t kAL1_{1};
uint64_t kBL1_{1};
uint64_t minKL1_{1};
uint64_t baseM_{16};
uint64_t baseN_{16};
uint64_t baseK_{16};
uint64_t aL1LoopCnt_{0};
uint64_t bL1LoopCnt_{0};
uint64_t l0PingPong_{0};
uint64_t l0cPingPong_{0};
bool orderAL1BL1_{false};
bool isBias_{false};
bool enableL0cPingPong_{false};
AscendC::LocalTensor<AType> l0aLocal_{AscendC::TPosition::A2, 0, L0A_SIZE};
AscendC::LocalTensor<BType> l0bLocal_{AscendC::TPosition::B2, 0, L0B_SIZE};
AscendC::LocalTensor<float> c1Local_{AscendC::TPosition::CO1, 0, AscendC::TOTAL_L0C_SIZE};
AscendC::LocalTensor<BiasType> biasBt_{AscendC::TPosition::C2, 0, 4096};
AscendC::LocalTensor<AType> aL1Local_{AscendC::TPosition::A1, 0, AscendC::TOTAL_L1_SIZE};
AscendC::LocalTensor<BType> bL1Local_{AscendC::TPosition::A1, 0, AscendC::TOTAL_L1_SIZE};
AscendC::LocalTensor<BiasType> biasL1Local_{AscendC::TPosition::A1, 0, AscendC::TOTAL_L1_SIZE / sizeof(BiasType)};
};
}
}
}
#endif