* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file block_epilogue_pertile.h
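* \brief Block-level epilogue for quantized grouped matmul: rescales each K-step matmul result with the x1/x2
*        scales, accumulates the partial results in UB and writes the final tile to global memory.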
*/
#ifndef EPILOGUE_BLOCK_EPILOGUE_PERTILE_H
#define EPILOGUE_BLOCK_EPILOGUE_PERTILE_H
#if ASC_DEVKIT_MAJOR >= 9
#include "kernel_basic_intf.h"
#else
#include "kernel_operator.h"
#endif
#include "../utils/common_utils.h"
#include "../utils/grouped_matmul_constant.h"
#include "../utils/layout_utils.h"
#include "../utils/tensor_utils.h"
namespace Cgmct {
namespace Gemm {
namespace Block {
// Template parameter list shared by the BlockEpiloguePerTile class template and its out-of-class member definitions.
#define QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS \
template <typename L0TileShape_, typename DataTypeOut_, typename DataTypeIn_, typename DataTypeBias_, \
typename DataTypeX1Scale_, typename DataTypeX2Scale_, typename LayoutX1Scale_, typename LayoutX2Scale_>
// Matching template argument list, used to spell BlockEpiloguePerTile<...> in out-of-class member definitions.
#define QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS \
L0TileShape_, DataTypeOut_, DataTypeIn_, DataTypeBias_, DataTypeX1Scale_, DataTypeX2Scale_, LayoutX1Scale_, \
LayoutX2Scale_
// Makes the names of Cgmct::Gemm::GroupedMatmul visible unqualified inside this namespace.
// NOTE(review): presumably the source of constants such as UB_SUB_BANK_NUM and GMM_BUFFER_NUM - confirm.
using namespace Cgmct::Gemm::GroupedMatmul;
// Per-call description of the slice of an output tile handled by the current vector sub-block.
// All fields are (re)computed by BlockEpiloguePerTile::UpdatePerBlockUBParam() before they are read.
struct PerBlockUBParam {
    // True when the two sub-blocks divide the tile along N; false when they divide it along M.
    bool CopyOutWithSplitN{false};
    // Number of N segments processed per row (1 or 2).
    uint16_t ndNum;
    // Nominal number of rows assigned to one sub-block.
    uint64_t singleM;
    // Nominal number of columns in one N segment.
    uint64_t singleN;
    // Rows that actually hold data for this sub-block.
    uint64_t validM;
    // Columns that actually hold data in each N segment.
    uint32_t validN[UB_SUB_BANK_NUM];
    // Offset of this sub-block along M, expressed in x1-scale groups.
    uint64_t offsetScaleM;
    // Offset of each N segment, expressed in x2-scale groups.
    uint64_t offsetScaleN[UB_SUB_BANK_NUM];
    // Element offset of each N segment inside the output tensor.
    uint64_t offsetY[UB_SUB_BANK_NUM];
};
// Positions of the per-tensor offsets inside a BlockCoord tuple (read with Get<...>()).
namespace {
constexpr uint32_t Y_IDX{0U};        // output tensor offset
constexpr uint32_t X2SCALE_IDX{1U};  // x2 scale offset
constexpr uint32_t X1SCALE_IDX{2U};  // x1 scale offset
constexpr uint32_t BIAS_IDX{3U};     // bias offset
} // namespace
/**
 * Vector-core epilogue for the quantized grouped matmul.
 *
 * For every K step it waits for a matmul partial result in UB, multiplies it by the matching x1/x2 scales,
 * accumulates the product into mmAddUb_ and, after the last K step, writes the tile to global memory
 * (casting to YType first when YType differs from CType).
 */
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
class BlockEpiloguePerTile {
public:
    /**
     * On vector cores, pre-sets the "buffer free" events so the first WaitFlag on each ping/pong buffer
     * returns immediately: V_MTE2 for the two x1-scale buffers and, when a cast is needed, MTE3_V for the
     * two cast-output buffers.
     */
    __aicore__ inline BlockEpiloguePerTile()
    {
        if ASCEND_IS_AIV {
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(0);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(1);
            if constexpr (!AscendC::IsSameType<CType, YType>::value) {
                AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(0);
                AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(1);
            }
        }
    }

    /**
     * Consumes the events that are still pending from the constructor / last use.
     *
     * Every wait is issued only on vector cores, mirroring the constructor: the V_MTE2 flags are never set on
     * a cube core, so waiting for them there has no matching SetFlag and can block the core.
     */
    __aicore__ inline ~BlockEpiloguePerTile()
    {
        if ASCEND_IS_AIV {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(0);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(1);
            if constexpr (!AscendC::IsSameType<CType, YType>::value) {
                AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(0);
                AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(1);
            }
        }
    }

    // Kernel-level arguments consumed by the epilogue.
    struct Arguments {
        GM_ADDR outGmAddr{nullptr};      // output tensor
        GM_ADDR x2ScaleGmAddr{nullptr};  // x2 scales
        GM_ADDR x1ScaleGmAddr{nullptr};  // x1 scales
        GM_ADDR biasGmAddr{nullptr};     // bias (not referenced by the code in this class)
        uint32_t baseM;
        uint32_t baseN;
        uint32_t baseK;                  // K advance per epilogue iteration
        uint32_t groupSizeM = 1U;        // rows sharing one x1 scale; 1 selects the per-tile path
        uint32_t groupSizeN = 128U;      // columns sharing one x2 scale
        uint32_t groupSizeK = 128U;      // K elements sharing one scale
        Arguments() = default;
    };
    using Params = Arguments;

    using YType = DataTypeOut_;
    using CType = DataTypeIn_;
    using BiasType = DataTypeBias_;
    using X2ScaleType = DataTypeX2Scale_;
    using X1ScaleType = DataTypeX1Scale_;
    using LayoutX1Scale = LayoutX1Scale_;
    using LayoutX2Scale = LayoutX2Scale_;
    using TupleShape = AscendC::Shape<int64_t, int64_t, int64_t>;
    using BlockCoord = AscendC::Coord<int64_t, int64_t, int64_t, int64_t>;
    // Whether the x1 / x2 scale tensors use the transposed layout.
    static constexpr bool transA = TagToTrans<LayoutX1Scale>::value;
    static constexpr bool transB = TagToTrans<LayoutX2Scale>::value;

public:
    // Stores the parameters and carves the UB buffers (no-op on cube cores).
    __aicore__ inline void Init(const Params* params);
    // Runs the epilogue for one output tile of shape actualSingleShape located at blockCoord.
    __aicore__ inline void operator()(const TupleShape& actualSingleShape, const BlockCoord& blockCoord);
    // Rebases the global tensors/pointers by the offsets stored in baseOffset.
    __aicore__ inline void UpdateGlobalAddr(const BlockCoord& baseOffset);
    // Records the problem shape and derives the number of scale groups along M, N and K.
    __aicore__ inline void UpdateParamsForNextProblem(const TupleShape& problemShape);
    // (Re)create and return the UB tensors that receive the matmul partial results.
    __aicore__ inline auto GetL0c2UbPingTensor();
    __aicore__ inline auto GetL0c2UbPongTensor();

private:
    // K loop for groupSizeM == 1: x1 scales are staged in UB, one scale per row.
    __aicore__ inline void ProcessAivSingleKPerTile(int64_t x1ScaleOffset,
                                                    __gm__ X2ScaleType* x2ScaleAddr[UB_SUB_BANK_NUM]);
    // K loop for groupSizeM != 1: a single scalar x1 scale per K step.
    __aicore__ inline void ProcessAivSingleKPerBlock(int64_t x1ScaleOffset,
                                                     __gm__ X2ScaleType* x2ScaleAddr[UB_SUB_BANK_NUM]);
    template <class T>
    __aicore__ inline __ubuf__ T* CopyInX1Scale(uint64_t srcOffset, uint64_t m, uint64_t k);
    template <class T>
    __aicore__ inline T CopyInX1ScalePerblock(__gm__ T* src, uint64_t offset);
    template <class T>
    __aicore__ inline void CopyInX2Scale(T x2Scale[UB_SUB_BANK_NUM], __gm__ T* src[UB_SUB_BANK_NUM], uint64_t offset);
    __aicore__ inline int64_t CalcX1OffsetPerGroup();
    __aicore__ inline void CalcX2OffsetPerGroup(int64_t x2ScaleOffset[UB_SUB_BANK_NUM]);
    template <class T>
    __aicore__ inline __ubuf__ T* GetX1ScaleUbAddrPerGroup(int64_t x1ScaleOffset, uint64_t kOffset, uint64_t kElem);
    // Per-tile scaling: x1 scales are read from UB.
    template <bool isFirstKLoop, uint32_t ndNum>
    __aicore__ inline void AivPerTensor(__ubuf__ CType* dst, __ubuf__ CType* l0cOut, __ubuf__ X1ScaleType* x1Scale,
                                        uint16_t mSize, uint32_t nSize0, uint32_t nSize1, uint16_t kSize,
                                        X2ScaleType x2Scale0, X2ScaleType x2Scale1, uint64_t x1ScaleKIdxInCache);
    // Per-block scaling: x1 scale is a scalar.
    template <bool isFirstKLoop, uint32_t ndNum>
    __aicore__ inline void AivPerTensor(__ubuf__ CType* dst, __ubuf__ CType* l0cOut, X1ScaleType x1Scale,
                                        uint16_t mSize, uint32_t nSize0, uint32_t nSize1, X2ScaleType x2Scale0,
                                        X2ScaleType x2Scale1);
    __aicore__ inline void AivPostProcess(const AscendC::LocalTensor<CType>& mmAddUb);
    __aicore__ inline void CopyOut(const AscendC::LocalTensor<YType>& ubRes, uint16_t eventId, uint16_t blkCount,
                                   uint32_t blkLen, uint32_t srcStride, uint32_t dstStride, uint64_t yOffset);
    __aicore__ inline void CastAndCopyOut(const AscendC::LocalTensor<CType>& mmAddUb);
    __aicore__ inline void UpdatePerBlockUBValidMN();
    __aicore__ inline void UpdatePerBlockUBParam();

    // Blocks until the cross-core flag for the given ping/pong slot is raised.
    __aicore__ inline void WaitForCube(uint16_t crossPingPongID)
    {
        AscendC::CrossCoreWaitFlag<GMM_AIC_SYNC_AIV_MODE, PIPE_V>(GMM_AIV_SYNC_AIC_FLAG + crossPingPongID);
    }

    // Raises the cross-core flag telling the cube side that the given ping/pong slot has been consumed.
    __aicore__ inline void NotifyCube(uint16_t crossPingPongID)
    {
        AscendC::CrossCoreSetFlag<GMM_AIC_SYNC_AIV_MODE, PIPE_V>(GMM_AIC_SYNC_AIV_FLAG + crossPingPongID);
    }

    AscendC::GlobalTensor<YType> cGlobal_;              // output tensor
    AscendC::GlobalTensor<X1ScaleType> x1ScaleGlobal_;  // x1 scales, per-tile path (DMA source)
    __gm__ X1ScaleType* x1ScaleGlobalPerblock_;         // x1 scales, per-block path (scalar reads)
    __gm__ X2ScaleType* x2ScaleGlobal_;                 // x2 scales (scalar reads)
    AscendC::LocalTensor<CType> mmResPing_;             // matmul partial result, slot 0
    AscendC::LocalTensor<CType> mmResPong_;             // matmul partial result, slot 1
    AscendC::LocalTensor<YType> ubResPing_;             // cast output, slot 0 (only when YType != CType)
    AscendC::LocalTensor<YType> ubResPong_;             // cast output, slot 1 (only when YType != CType)
    AscendC::LocalTensor<CType> mmAddUb_;               // accumulator over the K loop
    AscendC::LocalTensor<X1ScaleType> x1ScaleUbPing_;   // staged x1 scales, slot 0 (per-tile path)
    AscendC::LocalTensor<X1ScaleType> x1ScaleUbPong_;   // staged x1 scales, slot 1 (per-tile path)

private:
    const Params* params_;              // set by Init() on vector cores only
    PerBlockUBParam ubParams_;
    TupleShape problemShape_{};
    TupleShape actualSingleShape_{};
    BlockCoord baseOffset_{0, 0, 0, 0};
    BlockCoord blockCoord_{0, 0, 0, 0};
    uint64_t scaleM_ = 0;               // number of scale groups along M
    uint64_t scaleN_ = 0;               // number of scale groups along N
    uint64_t scaleK_ = 0;               // number of scale groups along K
    uint32_t subBlockIdx_;              // set by Init() on vector cores only
    uint16_t crossPingPongID_ = 0;      // matmul-result slot expected next
    uint16_t x1ScalePingPongID_ = 0;    // x1-scale buffer used next
    bool isPertile_ = false;            // true when groupSizeM == 1
};
// Stores the kernel parameters and lays out the epilogue buffers in UB (vector cores only).
// UB layout, in order: matmul-result buffers, accumulator, optional cast-output buffers, optional x1-scale
// ping/pong buffers.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::Init(const Params* params)
{
    // Cube cores own none of these buffers.
    if ASCEND_IS_AIC {
        return;
    }
    params_ = params;
    subBlockIdx_ = AscendC::GetSubBlockIdx();

    // The matmul-result tensors start at byte 0 of UB.
    (void)GetL0c2UbPingTensor();
    (void)GetL0c2UbPongTensor();

    constexpr uint32_t tileElems = UB_TWO_BANK_ELEMS_B32 * PER_BLOCK_SIZE;
    // Byte offset of the accumulator: right behind the matmul-result region.
    constexpr uint32_t accumByteOffset = tileElems * UB_SUB_BANK_NUM * sizeof(CType);
    // Byte offset of the first byte after the accumulator.
    constexpr uint32_t accumEndByteOffset = accumByteOffset + tileElems * sizeof(CType);
    constexpr bool outputNeedsNoCast = AscendC::IsSameType<CType, YType>::value;

    mmAddUb_ = AscendC::LocalTensor<CType>(AscendC::TPosition::VECCALC, accumByteOffset, tileElems);
    if constexpr (!outputNeedsNoCast) {
        // The cast output is split into two halves used as ping/pong.
        ubResPing_ = AscendC::LocalTensor<YType>(AscendC::TPosition::VECCALC, accumEndByteOffset, tileElems);
        ubResPong_ = ubResPing_[tileElems / GMM_BUFFER_NUM];
    }

    isPertile_ = (params_->groupSizeM == 1);
    if (!isPertile_) {
        return;
    }

    // Per-tile path only: two x1-scale staging buffers placed after everything else.
    constexpr uint32_t x1ScalePingByteOffset =
        outputNeedsNoCast ? accumEndByteOffset : accumEndByteOffset + tileElems * sizeof(YType);
    const auto x1ScaleElems = PER_BLOCK_SIZE * GMM_MAX_STEP_SCALEA_K;
    x1ScaleUbPing_ =
        AscendC::LocalTensor<X1ScaleType>(AscendC::TPosition::VECCALC, x1ScalePingByteOffset, x1ScaleElems);
    x1ScaleUbPong_ = AscendC::LocalTensor<X1ScaleType>(
        AscendC::TPosition::VECCALC, x1ScalePingByteOffset + x1ScaleElems * sizeof(X1ScaleType), x1ScaleElems);
}
// Rebuilds mmResPing_ so that it starts at byte 0 of UB and returns a copy of it.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline auto BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::GetL0c2UbPingTensor()
{
    constexpr uint32_t tileElems = UB_TWO_BANK_ELEMS_B32 * PER_BLOCK_SIZE;
    const auto pingElems = tileElems * UB_SUB_BANK_NUM;
    mmResPing_ = AscendC::LocalTensor<CType>(AscendC::TPosition::VECCALC, 0, pingElems);
    return mmResPing_;
}
// Rebuilds mmResPong_, which starts UB_SUB_BANK_ELEMS_B32 elements after the ping tensor and ends at the same
// address as it, and returns a copy of it.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline auto BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::GetL0c2UbPongTensor()
{
    constexpr uint32_t tileElems = UB_TWO_BANK_ELEMS_B32 * PER_BLOCK_SIZE;
    const auto pongByteOffset = UB_SUB_BANK_ELEMS_B32 * sizeof(CType);
    const auto pongElems = tileElems * UB_SUB_BANK_NUM - UB_SUB_BANK_ELEMS_B32;
    mmResPong_ = AscendC::LocalTensor<CType>(AscendC::TPosition::VECCALC, pongByteOffset, pongElems);
    return mmResPong_;
}
// Remembers the shape of the next problem and derives how many scale groups it spans along M, N and K
// (problem extent divided by the group size, rounded up).
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::UpdateParamsForNextProblem(const TupleShape& problemShape)
{
    problemShape_ = problemShape;
    const Params& args = *params_;
    scaleM_ = Cgmct::Gemm::CeilDiv(Get<MNK_M>(problemShape_), args.groupSizeM);
    scaleN_ = Cgmct::Gemm::CeilDiv(Get<MNK_N>(problemShape_), args.groupSizeN);
    scaleK_ = Cgmct::Gemm::CeilDiv(Get<MNK_K>(problemShape_), args.groupSizeK);
}
// Points the scale and output accessors at the kernel-argument addresses advanced by the element offsets
// stored in baseOffset. Does nothing on cube cores.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::UpdateGlobalAddr(const BlockCoord& baseOffset)
{
    if ASCEND_IS_AIV {
        // The x1 scales are reachable both through a GlobalTensor (DMA path) and a raw pointer (scalar path).
        auto x1ScaleBase =
            reinterpret_cast<__gm__ X1ScaleType*>(params_->x1ScaleGmAddr) + Get<X1SCALE_IDX>(baseOffset);
        x1ScaleGlobal_.SetGlobalBuffer(x1ScaleBase);
        x1ScaleGlobalPerblock_ = x1ScaleBase;

        x2ScaleGlobal_ = reinterpret_cast<__gm__ X2ScaleType*>(params_->x2ScaleGmAddr) + Get<X2SCALE_IDX>(baseOffset);

        auto outBase = reinterpret_cast<__gm__ YType*>(params_->outGmAddr) + Get<Y_IDX>(baseOffset);
        cGlobal_.SetGlobalBuffer(outBase);
    }
}
// Returns the element offset of this sub-block's first x1 scale.
// Sub-block 1 is shifted by offsetScaleM scale rows; a scale row is 1 element apart in the transposed layout
// and scaleK_ elements apart otherwise.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline int64_t BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::CalcX1OffsetPerGroup()
{
    const int64_t tileOffset = Get<X1SCALE_IDX>(blockCoord_);
    if (subBlockIdx_ != 1) {
        return tileOffset;
    }
    if constexpr (transA) {
        return tileOffset + ubParams_.offsetScaleM;
    } else {
        return tileOffset + (ubParams_.offsetScaleM * scaleK_);
    }
}
// Writes the element offset of the first x2 scale of N segment 0 and 1 into x2ScaleOffset.
// Consecutive N scale groups are scaleK_ elements apart in the transposed layout and adjacent otherwise.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::CalcX2OffsetPerGroup(
    int64_t x2ScaleOffset[UB_SUB_BANK_NUM])
{
    const int64_t tileOffset = Get<X2SCALE_IDX>(blockCoord_);
    uint64_t groupPitch = 1;
    if constexpr (transB) {
        groupPitch = scaleK_;
    }
    x2ScaleOffset[0] = tileOffset + (ubParams_.offsetScaleN[0] * groupPitch);
    x2ScaleOffset[1] = tileOffset + (ubParams_.offsetScaleN[1] * groupPitch);
}
// Copies a window of x1 scales (m scales along M, k scales along K) from x1ScaleGlobal_ into the x1-scale
// UB buffer selected by x1ScalePingPongID_ and returns that buffer's UB address.
//   srcOffset: element offset of the window's first scale inside x1ScaleGlobal_.
//   m, k:      window extent along M and K.
// The buffer is not released here; the caller sets V_MTE2 again once it has finished reading it.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
template <class T>
__aicore__ inline __ubuf__ T*
BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::CopyInX1Scale(uint64_t srcOffset, uint64_t m, uint64_t k)
{
AscendC::DataCopyExtParams x1ScaleGm2UbParams;
AscendC::DataCopyPadExtParams<X1ScaleType> padParams;
// Transposed layout: k rows of m scales with a source row pitch of scaleM_ elements.
// Otherwise: m rows of k scales with a source row pitch of scaleK_ elements.
// srcStride is the number of bytes skipped between the end of one row and the start of the next.
if constexpr (transA) {
x1ScaleGm2UbParams.blockCount = k;
x1ScaleGm2UbParams.blockLen = m * sizeof(T);
x1ScaleGm2UbParams.srcStride = (scaleM_ - m) * sizeof(T);
} else {
x1ScaleGm2UbParams.blockCount = m;
x1ScaleGm2UbParams.blockLen = k * sizeof(T);
x1ScaleGm2UbParams.srcStride = (scaleK_ - k) * sizeof(T);
}
auto x1ScaleUb = x1ScalePingPongID_ == 0 ? &x1ScaleUbPing_ : &x1ScaleUbPong_;
// Wait for the V_MTE2 event of this buffer (set by the constructor or by ProcessAivSingleKPerTile after the
// previous use) before overwriting it.
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(x1ScalePingPongID_);
AscendC::DataCopyPad(*x1ScaleUb, x1ScaleGlobal_[srcOffset], x1ScaleGm2UbParams, padParams);
// MTE2 -> V handshake so the vector code that follows reads the freshly copied scales.
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(x1ScalePingPongID_);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(x1ScalePingPongID_);
return reinterpret_cast<__ubuf__ T*>(x1ScaleUb->GetPhyAddr());
}
// Reads the x1 scale of K group `offset` directly from global memory.
// Consecutive K groups are scaleM_ elements apart in the transposed layout and adjacent otherwise.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
template <class T>
__aicore__ inline T BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::CopyInX1ScalePerblock(__gm__ T* src,
                                                                                                      uint64_t offset)
{
    uint64_t elemIdx = offset;
    if constexpr (transA) {
        elemIdx = offset * scaleM_;
    }
    return src[elemIdx];
}
// Reads the x2 scale of K group `offset` for N segment 0 and 1 directly from global memory.
// Consecutive K groups are adjacent in the transposed layout and scaleN_ elements apart otherwise.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
template <class T>
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::CopyInX2Scale(
    T x2Scale[UB_SUB_BANK_NUM], __gm__ T* src[UB_SUB_BANK_NUM], uint64_t offset)
{
    uint64_t elemIdx = offset;
    if constexpr (!transB) {
        elemIdx = offset * scaleN_;
    }
    x2Scale[0] = src[0][elemIdx];
    x2Scale[1] = src[1][elemIdx];
}
// Computes how many rows (validM) and how many columns per N segment (validN[0..1]) of the current tile
// really belong to this sub-block, clamping the nominal singleM/singleN against the actual tile shape.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::UpdatePerBlockUBValidMN()
{
    int64_t tileN = Get<MNK_N>(actualSingleShape_);

    if (!ubParams_.CopyOutWithSplitN) {
        // Split along M: sub-block 0 takes the first singleM rows, the other sub-block takes the remainder.
        if (AscendC::GetSubBlockIdx() == 0) {
            ubParams_.validM = ubParams_.singleM;
        } else {
            ubParams_.validM = Get<MNK_M>(actualSingleShape_) - ubParams_.singleM;
        }
        // The two N segments are adjacent: [0, singleN) and [singleN, 2 * singleN).
        ubParams_.validN[0] = Min(ubParams_.singleN, static_cast<uint64_t>(tileN));
        if (tileN < ubParams_.singleN) {
            ubParams_.validN[1] = 0;
        } else {
            ubParams_.validN[1] = Min(ubParams_.singleN, tileN - ubParams_.singleN);
        }
        return;
    }

    // Split along N: every sub-block sees all rows; its segments start at
    // subBlockIdx * singleN and 2 * singleN + subBlockIdx * singleN.
    ubParams_.validM = ubParams_.singleM;
    uint64_t firstSegStart = AscendC::GetSubBlockIdx() * ubParams_.singleN;
    uint64_t secondSegStart = 2 * ubParams_.singleN + firstSegStart;
    if (tileN < firstSegStart) {
        ubParams_.validN[0] = 0;
    } else {
        ubParams_.validN[0] = Min(ubParams_.singleN, tileN - firstSegStart);
    }
    if (tileN < secondSegStart) {
        ubParams_.validN[1] = 0;
    } else {
        ubParams_.validN[1] = Min(ubParams_.singleN, tileN - secondSegStart);
    }
}
// Recomputes ubParams_ for the tile described by actualSingleShape_ / blockCoord_: how the tile is divided
// between the sub-blocks, the valid extents, and the resulting scale and output offsets.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::UpdatePerBlockUBParam()
{
// Divide along N when the tile is wider than one N scale group or has a single row; otherwise along M.
ubParams_.CopyOutWithSplitN =
Get<MNK_N>(actualSingleShape_) > params_->groupSizeN || Get<MNK_M>(actualSingleShape_) == 1;
uint32_t fixpipeN = 0;
if (ubParams_.CopyOutWithSplitN) {
// N split: every sub-block keeps all rows; the aligned N of one segment is shared between the sub-blocks
// (divided by GetTaskRation()).
ubParams_.ndNum = Get<MNK_N>(actualSingleShape_) > UB_TWO_BANK_ELEMS_B32 ? 2 : 1;
int64_t alignedNBase =
Get<MNK_N>(actualSingleShape_) > PER_BLOCK_SIZE ? PER_BLOCK_SIZE : AscendC::ONE_BLK_SIZE * ubParams_.ndNum;
fixpipeN = Align(Get<MNK_N>(actualSingleShape_), static_cast<uint64_t>(alignedNBase)) / ubParams_.ndNum;
ubParams_.singleN = fixpipeN / static_cast<uint32_t>(AscendC::GetTaskRation());
ubParams_.singleM = Get<MNK_M>(actualSingleShape_);
} else {
// M split: rows are divided between the sub-blocks (rounded up); each sub-block sees the full segment width.
ubParams_.ndNum = Get<MNK_N>(actualSingleShape_) > UB_SUB_BANK_ELEMS_B32 ? 2 : 1;
fixpipeN = Align(Get<MNK_N>(actualSingleShape_), static_cast<uint64_t>(AscendC::BLOCK_CUBE) * ubParams_.ndNum) /
ubParams_.ndNum;
ubParams_.singleN = fixpipeN;
ubParams_.singleM = CeilDiv(Get<MNK_M>(actualSingleShape_), AscendC::GetTaskRation());
}
UpdatePerBlockUBValidMN();
// Element offsets of this sub-block's data relative to the tile origin.
int64_t offsetM = 0;
int64_t offsetN0 = 0;
int64_t offsetN1 = 0;
if (ubParams_.CopyOutWithSplitN) {
// Segment 0 starts at subBlockIdx * singleN, segment 1 a further UB_SUB_BANK_NUM * singleN columns on.
// An empty segment reuses the previous offset so that no out-of-range offset is produced.
offsetN0 = ubParams_.validN[0] == 0 ? 0 : AscendC::GetSubBlockIdx() * ubParams_.singleN;
offsetN1 = ubParams_.validN[1] == 0 ? offsetN0 : offsetN0 + UB_SUB_BANK_NUM * ubParams_.singleN;
} else {
if (AscendC::GetSubBlockIdx() == 1) {
offsetM += ubParams_.singleM;
}
offsetN1 = ubParams_.validN[1] == 0 ? 0 : ubParams_.singleN;
}
// Scale offsets are expressed in scale groups, output offsets in elements of a row-major output whose row
// pitch is the problem's N.
ubParams_.offsetScaleM = offsetM / params_->groupSizeM;
ubParams_.offsetScaleN[0] = offsetN0 / params_->groupSizeN;
ubParams_.offsetScaleN[1] = offsetN1 / params_->groupSizeN;
ubParams_.offsetY[0] = Get<Y_IDX>(blockCoord_) + offsetM * Get<MNK_N>(problemShape_) + offsetN0;
ubParams_.offsetY[1] = Get<Y_IDX>(blockCoord_) + offsetM * Get<MNK_N>(problemShape_) + offsetN1;
}
// Runs the epilogue for one output tile.
//   actualSingleShape: actual shape of the tile.
//   blockCoord:        per-tensor element offsets of the tile (indexed with Y_IDX / X2SCALE_IDX / X1SCALE_IDX).
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::operator()(const TupleShape& actualSingleShape,
                                                                        const BlockCoord& blockCoord)
{
    actualSingleShape_ = actualSingleShape;
    blockCoord_ = blockCoord;
    UpdatePerBlockUBParam();

    // Resolve where this sub-block's scales start.
    const int64_t x1ScaleStart = CalcX1OffsetPerGroup();
    int64_t x2ScaleStart[UB_SUB_BANK_NUM] = {0};
    CalcX2OffsetPerGroup(x2ScaleStart);
    __gm__ X2ScaleType* x2ScalePtrs[UB_SUB_BANK_NUM] = {x2ScaleGlobal_ + x2ScaleStart[0],
                                                        x2ScaleGlobal_ + x2ScaleStart[1]};

    // groupSizeM == 1 selects the per-tile K loop, anything else the per-block K loop.
    if (!isPertile_) {
        ProcessAivSingleKPerBlock(x1ScaleStart, x2ScalePtrs);
        return;
    }
    ProcessAivSingleKPerTile(x1ScaleStart, x2ScalePtrs);
}
// Stages the x1 scales of kElem K groups starting at K group kOffset (validM scales along M) into UB and
// returns their UB address.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
template <class T>
__aicore__ inline __ubuf__ T*
BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::GetX1ScaleUbAddrPerGroup(int64_t x1ScaleOffset,
                                                                                      uint64_t kOffset, uint64_t kElem)
{
    // A K group is scaleM_ elements further on in the transposed layout and one element further otherwise.
    uint64_t gmElemOffset = static_cast<uint64_t>(x1ScaleOffset);
    if constexpr (transA) {
        gmElemOffset += kOffset * scaleM_;
    } else {
        gmElemOffset += kOffset;
    }
    return CopyInX1Scale<X1ScaleType>(gmElemOffset, ubParams_.validM, kElem);
}
// K loop of the per-tile path (groupSizeM == 1).
// For each K step: read the x2 scales, (re)stage a window of x1 scales when the previous window is used up,
// wait for the matmul partial result, scale it and accumulate it into mmAddUb_, then hand the result slot back.
// After the loop the accumulated tile is written out.
//   x1ScaleOffset: element offset of this sub-block's first x1 scale.
//   x2ScaleAddr:   start of the x2 scales for N segment 0 and 1.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::ProcessAivSingleKPerTile(
int64_t x1ScaleOffset, __gm__ X2ScaleType* x2ScaleAddr[UB_SUB_BANK_NUM])
{
auto mmAddUbAddr = reinterpret_cast<__ubuf__ CType*>(mmAddUb_.GetPhyAddr());
// Number of K groups staged per x1-scale window.
const uint16_t x1ScaleKElem = Min(GMM_MAX_STEP_SCALEA_K, scaleK_);
uint64_t kElem;
__ubuf__ X1ScaleType* x1ScaleUbAddr;
X2ScaleType x2Scale[UB_SUB_BANK_NUM];
// kb walks K in elements, kOffset counts K steps and doubles as the scale-group index.
// NOTE(review): this presumably relies on baseK being equal to groupSizeK - confirm.
for (uint64_t kb = 0, kOffset = 0; kb < Get<MNK_K>(problemShape_); kb += params_->baseK, kOffset++) {
CopyInX2Scale<X2ScaleType>(x2Scale, x2ScaleAddr, kOffset);
// Position of this K step inside the staged window; 0 means a new window has to be staged.
uint64_t x1ScaleKRem = kOffset % x1ScaleKElem;
if (x1ScaleKRem == 0) {
kElem = Min(static_cast<uint64_t>(x1ScaleKElem), scaleK_ - kOffset);
x1ScaleUbAddr = GetX1ScaleUbAddrPerGroup<X1ScaleType>(x1ScaleOffset, kOffset, kElem);
}
WaitForCube(crossPingPongID_);
auto mmUbInputAddr = crossPingPongID_ == 0 ? reinterpret_cast<uint64_t>(mmResPing_.GetPhyAddr()) :
reinterpret_cast<uint64_t>(mmResPong_.GetPhyAddr());
// The first K step stores into the accumulator, later steps add to it; ndNum picks 1 or 2 N segments.
if (kb == 0) {
if (ubParams_.ndNum == 1) {
AivPerTensor<true, 1U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1ScaleUbAddr,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], kElem, x2Scale[0],
x2Scale[1], x1ScaleKRem);
} else {
AivPerTensor<true, 2U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1ScaleUbAddr,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], kElem, x2Scale[0],
x2Scale[1], x1ScaleKRem);
}
} else {
if (ubParams_.ndNum == 1) {
AivPerTensor<false, 1U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1ScaleUbAddr,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], kElem, x2Scale[0],
x2Scale[1], x1ScaleKRem);
} else {
AivPerTensor<false, 2U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1ScaleUbAddr,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], kElem, x2Scale[0],
x2Scale[1], x1ScaleKRem);
}
}
// Last use of the staged window (window exhausted or last K group): release the buffer to MTE2 and switch
// to the other x1-scale buffer.
if (x1ScaleKRem == x1ScaleKElem - 1 || kOffset == scaleK_ - 1) {
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(x1ScalePingPongID_);
x1ScalePingPongID_ = (x1ScalePingPongID_ + 1) & 1;
}
// Return the matmul-result slot and move on to the other one.
NotifyCube(crossPingPongID_);
crossPingPongID_ = (crossPingPongID_ + 1) & 1;
}
AivPostProcess(mmAddUb_);
}
// K loop of the per-block path (groupSizeM != 1).
// For each K step: read one scalar x1 scale and the two x2 scales from global memory, wait for the matmul
// partial result, scale it and accumulate it into mmAddUb_, then hand the result slot back.
// After the loop the accumulated tile is written out.
//   x1ScaleOffset: element offset of this sub-block's first x1 scale.
//   x2ScaleAddr:   start of the x2 scales for N segment 0 and 1.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::ProcessAivSingleKPerBlock(
int64_t x1ScaleOffset, __gm__ X2ScaleType* x2ScaleAddr[UB_SUB_BANK_NUM])
{
auto mmAddUbAddr = reinterpret_cast<__ubuf__ CType*>(mmAddUb_.GetPhyAddr());
auto x1ScaleAddr = x1ScaleGlobalPerblock_ + x1ScaleOffset;
X2ScaleType x2Scale[UB_SUB_BANK_NUM];
// kb walks K in elements, kOffset counts K steps and doubles as the scale-group index.
// NOTE(review): this presumably relies on baseK being equal to groupSizeK - confirm.
for (uint64_t kb = 0, kOffset = 0; kb < Get<MNK_K>(problemShape_); kb += params_->baseK, kOffset++) {
CopyInX2Scale<X2ScaleType>(x2Scale, x2ScaleAddr, kOffset);
X1ScaleType x1Scale = CopyInX1ScalePerblock(x1ScaleAddr, kOffset);
WaitForCube(crossPingPongID_);
auto mmUbInputAddr = crossPingPongID_ == 0 ? reinterpret_cast<uint64_t>(mmResPing_.GetPhyAddr()) :
reinterpret_cast<uint64_t>(mmResPong_.GetPhyAddr());
// The first K step stores into the accumulator, later steps add to it; ndNum picks 1 or 2 N segments.
if (kb == 0) {
if (ubParams_.ndNum == 1) {
AivPerTensor<true, 1U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1Scale,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], x2Scale[0],
x2Scale[1]);
} else {
AivPerTensor<true, 2U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1Scale,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], x2Scale[0],
x2Scale[1]);
}
} else {
if (ubParams_.ndNum == 1) {
AivPerTensor<false, 1U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1Scale,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], x2Scale[0],
x2Scale[1]);
} else {
AivPerTensor<false, 2U>((__ubuf__ CType*)mmAddUbAddr, (__ubuf__ CType*)mmUbInputAddr, x1Scale,
ubParams_.validM, ubParams_.validN[0], ubParams_.validN[1], x2Scale[0],
x2Scale[1]);
}
}
// Return the matmul-result slot and move on to the other one.
NotifyCube(crossPingPongID_);
crossPingPongID_ = (crossPingPongID_ + 1) & 1;
}
AivPostProcess(mmAddUb_);
}
// Per-tile scaling of one matmul partial result: for every row, dst (+)= l0cOut * (x1Scale[row] * x2Scale).
//   isFirstKLoop:       true stores the product into dst, false adds it to what dst already holds.
//   ndNum:              number of N segments handled per row (1 or 2).
//   dst:                accumulator in UB; rows are UB_TWO_BANK_ELEMS_B32 elements apart and segment 1 is stored
//                       directly behind the nSize0 elements of segment 0.
//   l0cOut:             matmul partial result in UB; segment 1 starts UB_TWO_BANK_ELEMS_B32 * PER_BLOCK_SIZE
//                       elements after segment 0.
//   x1Scale:            staged x1 scales in UB.
//   mSize:              number of rows to process.
//   nSize0, nSize1:     valid columns of segment 0 / 1.
//   kSize:              number of K groups held in the staged x1-scale window.
//   x2Scale0, x2Scale1: x2 scale of segment 0 / 1.
//   x1ScaleKIdxInCache: index of the current K group inside the staged window.
// NOTE(review): each segment is processed with a single masked register operation per row, so nSize0/nSize1
// presumably never exceed one vector register - confirm.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
template <bool isFirstKLoop, uint32_t ndNum>
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::AivPerTensor(
__ubuf__ CType* dst, __ubuf__ CType* l0cOut, __ubuf__ X1ScaleType* x1Scale, uint16_t mSize, uint32_t nSize0,
uint32_t nSize1, uint16_t kSize, X2ScaleType x2Scale0, X2ScaleType x2Scale1, uint64_t x1ScaleKIdxInCache)
{
// Row pitch of the staged x1 scales: the copied row length rounded up to GMM_UB_ALIGN_SIZE bytes.
uint16_t alignM = Align(mSize, GMM_UB_ALIGN_SIZE / sizeof(X1ScaleType));
uint16_t alignK = Align(kSize, GMM_UB_ALIGN_SIZE / sizeof(X1ScaleType));
__VEC_SCOPE__
{
for (uint16_t mIdx = 0; mIdx < mSize; mIdx++) {
AscendC::MicroAPI::RegTensor<X1ScaleType> x1ScaleReg, muledScaleReg;
// Load the x1 scale of (row mIdx, K group x1ScaleKIdxInCache) with the DIST_BRC_B32 load distribution.
// The staged window is laid out [k][m] in the transposed layout and [m][k] otherwise.
if constexpr (transA) {
AscendC::MicroAPI::DataCopy<X1ScaleType, AscendC::MicroAPI::LoadDist::DIST_BRC_B32>(
x1ScaleReg, x1Scale + x1ScaleKIdxInCache * alignM + mIdx);
} else {
AscendC::MicroAPI::DataCopy<X1ScaleType, AscendC::MicroAPI::LoadDist::DIST_BRC_B32>(
x1ScaleReg, x1Scale + mIdx * alignK + x1ScaleKIdxInCache);
}
// ---- N segment 0 ----
uint32_t elementNum = nSize0;
AscendC::MicroAPI::MaskReg maskN = AscendC::MicroAPI::UpdateMask<CType>(elementNum);
AscendC::MicroAPI::RegTensor<CType> l0cOutReg;
AscendC::MicroAPI::RegTensor<CType> addReg;
AscendC::MicroAPI::RegTensor<CType> ResReg, mulScaleOutReg;
uint32_t offset = mIdx * UB_TWO_BANK_ELEMS_B32;
uint32_t l0cOutOffset = offset;
AscendC::MicroAPI::DataCopy(l0cOutReg, l0cOut + offset);
// Combined scale = x1 scale * x2 scale, then applied to the partial result.
AscendC::MicroAPI::Muls(muledScaleReg, x1ScaleReg, x2Scale0, maskN);
AscendC::MicroAPI::Mul(mulScaleOutReg, l0cOutReg, muledScaleReg, maskN);
uint32_t dstUbOffset = offset;
if constexpr (isFirstKLoop) {
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
mulScaleOutReg, maskN);
} else {
AscendC::MicroAPI::DataCopy(addReg, dst + dstUbOffset);
AscendC::MicroAPI::Add(ResReg, mulScaleOutReg, addReg, maskN);
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
ResReg, maskN);
}
if constexpr (ndNum == 1) {
continue;
}
// ---- N segment 1 ----
elementNum = nSize1;
maskN = AscendC::MicroAPI::UpdateMask<CType>(elementNum);
l0cOutOffset = offset + UB_TWO_BANK_ELEMS_B32 * PER_BLOCK_SIZE;
AscendC::MicroAPI::DataCopy(l0cOutReg, l0cOut + l0cOutOffset);
AscendC::MicroAPI::Muls(muledScaleReg, x1ScaleReg, x2Scale1, maskN);
AscendC::MicroAPI::Mul(mulScaleOutReg, l0cOutReg, muledScaleReg, maskN);
dstUbOffset = offset + nSize0;
if constexpr (isFirstKLoop) {
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
mulScaleOutReg, maskN);
} else {
AscendC::MicroAPI::DataCopy(addReg, dst + dstUbOffset);
AscendC::MicroAPI::Add(ResReg, mulScaleOutReg, addReg, maskN);
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
ResReg, maskN);
}
}
}
}
// Per-block scaling of one matmul partial result: for every row, dst (+)= l0cOut * (x1Scale * x2Scale),
// where x1Scale is a single scalar shared by all rows.
//   isFirstKLoop:       true stores the product into dst, false adds it to what dst already holds.
//   ndNum:              number of N segments handled per row (1 or 2).
//   dst:                accumulator in UB; rows are UB_TWO_BANK_ELEMS_B32 elements apart and segment 1 is stored
//                       directly behind the nSize0 elements of segment 0.
//   l0cOut:             matmul partial result in UB; segment 1 starts UB_TWO_BANK_ELEMS_B32 * PER_BLOCK_SIZE
//                       elements after segment 0.
//   mSize:              number of rows to process.
//   nSize0, nSize1:     valid columns of segment 0 / 1.
//   x2Scale0, x2Scale1: x2 scale of segment 0 / 1.
// NOTE(review): each segment is processed with a single masked register operation per row, so nSize0/nSize1
// presumably never exceed one vector register - confirm.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
template <bool isFirstKLoop, uint32_t ndNum>
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::AivPerTensor(
__ubuf__ CType* dst, __ubuf__ CType* l0cOut, X1ScaleType x1Scale, uint16_t mSize, uint32_t nSize0, uint32_t nSize1,
X2ScaleType x2Scale0, X2ScaleType x2Scale1)
{
__VEC_SCOPE__
{
for (uint16_t mIdx = 0; mIdx < mSize; mIdx++) {
// ---- N segment 0 ----
// Combined scalar scale for segment 0.
X2ScaleType scaleMul = x1Scale * x2Scale0;
uint32_t elementNum = nSize0;
AscendC::MicroAPI::MaskReg maskN = AscendC::MicroAPI::UpdateMask<CType>(elementNum);
AscendC::MicroAPI::RegTensor<CType> l0cOutReg;
AscendC::MicroAPI::RegTensor<CType> addReg;
AscendC::MicroAPI::RegTensor<CType> ResReg, mulScaleOutReg;
uint32_t offset = mIdx * UB_TWO_BANK_ELEMS_B32;
uint32_t l0cOutOffset = offset;
AscendC::MicroAPI::DataCopy(l0cOutReg, l0cOut + offset);
AscendC::MicroAPI::Muls(mulScaleOutReg, l0cOutReg, scaleMul, maskN);
uint32_t dstUbOffset = offset;
if constexpr (isFirstKLoop) {
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
mulScaleOutReg, maskN);
} else {
AscendC::MicroAPI::DataCopy(addReg, dst + dstUbOffset);
AscendC::MicroAPI::Add(ResReg, mulScaleOutReg, addReg, maskN);
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
ResReg, maskN);
}
if constexpr (ndNum == 1) {
continue;
}
// ---- N segment 1 ----
scaleMul = x1Scale * x2Scale1;
elementNum = nSize1;
maskN = AscendC::MicroAPI::UpdateMask<CType>(elementNum);
l0cOutOffset = offset + UB_TWO_BANK_ELEMS_B32 * PER_BLOCK_SIZE;
AscendC::MicroAPI::DataCopy(l0cOutReg, l0cOut + l0cOutOffset);
AscendC::MicroAPI::Muls(mulScaleOutReg, l0cOutReg, scaleMul, maskN);
dstUbOffset = offset + nSize0;
if constexpr (isFirstKLoop) {
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
mulScaleOutReg, maskN);
} else {
AscendC::MicroAPI::DataCopy(addReg, dst + dstUbOffset);
AscendC::MicroAPI::Add(ResReg, mulScaleOutReg, addReg, maskN);
AscendC::MicroAPI::DataCopy<CType, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(dst + dstUbOffset,
ResReg, maskN);
}
}
}
}
// Writes the accumulated tile in mmAddUb to global memory.
// When YType equals CType the accumulator is copied out directly; otherwise it is cast first
// (see CastAndCopyOut). Nothing is written when this sub-block owns no rows.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::AivPostProcess(const AscendC::LocalTensor<CType>& mmAddUb)
{
if (ubParams_.validM == 0) {
return;
}
if constexpr (AscendC::IsSameType<YType, CType>::value) {
// V -> MTE3 handshake: the accumulation has to be finished before the copy-out reads the accumulator.
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(0);
if (ubParams_.ndNum == 2 && !ubParams_.CopyOutWithSplitN) {
// M split with two segments: both segments sit back to back in each accumulator row and in the
// output, so one copy of validN[0] + validN[1] columns per row is enough.
uint32_t sumN = ubParams_.validN[0] + ubParams_.validN[1];
CopyOut(mmAddUb, 0, ubParams_.validM, sumN, UB_TWO_BANK_ELEMS_B32 - sumN, Get<MNK_N>(problemShape_) - sumN,
ubParams_.offsetY[0]);
} else {
// Otherwise each non-empty segment is copied on its own to its own output offset; segment 1 starts
// validN[0] elements into the accumulator row.
for (uint64_t ndIdx = 0; ndIdx < ubParams_.ndNum; ndIdx++) {
if (ubParams_.validN[ndIdx] > 0) {
CopyOut(mmAddUb[ndIdx * ubParams_.validN[0]], 0, ubParams_.validM, ubParams_.validN[ndIdx],
UB_TWO_BANK_ELEMS_B32 - ubParams_.validN[ndIdx],
Get<MNK_N>(problemShape_) - ubParams_.validN[ndIdx], ubParams_.offsetY[ndIdx]);
}
}
}
// MTE3 -> V handshake: the copy-out has to be finished before the accumulator is overwritten again.
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(0);
} else {
AscendC::PipeBarrier<PIPE_V>();
CastAndCopyOut(mmAddUb);
}
}
// Copies blkCount rows of blkLen elements from ubRes to the output tensor at element offset yOffset.
//   srcStride: elements skipped between rows in UB; converted to whole AscendC::ONE_BLK_SIZE-byte blocks.
//   dstStride: elements skipped between rows in global memory; converted to bytes.
//   eventId:   not used by this function.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::CopyOut(
    const AscendC::LocalTensor<YType> &ubRes, uint16_t eventId, uint16_t blkCount, uint32_t blkLen, uint32_t srcStride,
    uint32_t dstStride, uint64_t yOffset)
{
    const uint32_t rowBytes = static_cast<uint32_t>(blkLen * sizeof(YType));
    const uint32_t srcGapBlocks = static_cast<uint32_t>(srcStride * sizeof(YType) / AscendC::ONE_BLK_SIZE);
    const uint32_t dstGapBytes = static_cast<uint32_t>(dstStride * sizeof(YType));
    AscendC::DataCopyExtParams copyParams{blkCount, rowBytes, srcGapBlocks, dstGapBytes, 0};
    AscendC::DataCopyPad<YType>(cGlobal_[yOffset], ubRes, copyParams);
}
// Casts the accumulated tile from CType to YType (RoundMode::CAST_RINT) into the ubResPing_/ubResPong_
// buffers and copies the result to global memory.
// Each buffer is guarded by the MTE3_V event with the same index: it is waited on before the cast overwrites
// the buffer and set again once the copy-out of that buffer has been issued.
QGMM_BLOCK_EPILOGUE_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpiloguePerTile<QGMM_BLOCK_EPILOGUE_FUNC_LOCAL_PARAMS>::CastAndCopyOut(const AscendC::LocalTensor<CType> &mmAddUb)
{
if (ubParams_.ndNum == 2 && !ubParams_.CopyOutWithSplitN) {
// M split with two segments: rows are contiguous in N, so the rows are divided into two halves
// (the first half rounded up) that use the ping and the pong buffer respectively.
uint32_t sumN = ubParams_.validN[0] + ubParams_.validN[1];
uint32_t mSizePing = CeilDiv(ubParams_.validM, static_cast<uint64_t>(GMM_BUFFER_NUM));
uint32_t mSize[GMM_BUFFER_NUM] = {mSizePing, static_cast<uint32_t>(ubParams_.validM - mSizePing)};
for (uint32_t mDbIdx = 0; mDbIdx < GMM_BUFFER_NUM; ++mDbIdx) {
if (mSize[mDbIdx] > 0 && sumN > 0) {
auto ubRes = mDbIdx == 0 ? &ubResPing_ : &ubResPong_;
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(mDbIdx);
// Whole accumulator rows (UB_TWO_BANK_ELEMS_B32 elements each) are cast; only sumN columns per row
// are copied out afterwards.
AscendC::Cast(*ubRes, mmAddUb[mDbIdx * mSizePing * UB_TWO_BANK_ELEMS_B32],
AscendC::RoundMode::CAST_RINT, mSize[mDbIdx] * UB_TWO_BANK_ELEMS_B32);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(mDbIdx);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(mDbIdx);
CopyOut(*ubRes, mDbIdx, mSize[mDbIdx], sumN, UB_TWO_BANK_ELEMS_B32 - sumN,
Get<MNK_N>(problemShape_) - sumN,
ubParams_.offsetY[0] + mDbIdx * mSizePing * Get<MNK_N>(problemShape_));
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(mDbIdx);
}
}
} else {
// Otherwise each non-empty N segment is cast into its own buffer (segment 0 -> ping, segment 1 -> pong)
// and copied to its own output offset.
for (uint64_t ndIdx = 0; ndIdx < ubParams_.ndNum; ndIdx++) {
auto ubRes = ndIdx == 0 ? &ubResPing_ : &ubResPong_;
if (ubParams_.validN[ndIdx] > 0) {
// One repeat per row: validN[ndIdx] elements are cast per repeat and validM repeats are issued.
// Destination rows are singleN elements apart (rounded up to whole blocks); the source row pitch is
// GMM_BMM_BLOCK_NUM blocks.
AscendC::UnaryRepeatParams repeatParam;
repeatParam.srcBlkStride = 1;
repeatParam.dstBlkStride = 1;
repeatParam.dstRepStride = CeilDiv(ubParams_.singleN, AscendC::ONE_BLK_SIZE / sizeof(YType));
repeatParam.srcRepStride = GMM_BMM_BLOCK_NUM;
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(ndIdx);
AscendC::Cast(*ubRes, mmAddUb[ubParams_.validN[0] * ndIdx], AscendC::RoundMode::CAST_RINT,
ubParams_.validN[ndIdx], ubParams_.validM, repeatParam);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(ndIdx);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(ndIdx);
CopyOut(*ubRes, ndIdx, ubParams_.validM, ubParams_.validN[ndIdx],
ubParams_.singleN - ubParams_.validN[ndIdx],
Get<MNK_N>(problemShape_) - ubParams_.validN[ndIdx], ubParams_.offsetY[ndIdx]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(ndIdx);
}
}
}
}
}
}
}
#endif