/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
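/*!
* \file block_epilogue_dequant_swiglu.h
* \brief Block-level epilogue that dequantizes two matmul result tiles with x1/x2 scales and applies SwiGLU.
*/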
#ifndef EPILOGUE_BLOCK_EPILOGUE_DEQUANT_SWIGLU_H
#define EPILOGUE_BLOCK_EPILOGUE_DEQUANT_SWIGLU_H
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510)
#include "kernel_operator.h"
#include "../utils/common_utils.h"
#include "../utils/device_utils.h"
#include "../utils/status_utils.h"
#include "../utils/tensor_utils.h"
#include "../tile/tile_copy_policy.h"
namespace Cgmct {
namespace Gemm {
namespace Block {
// Constants used only by the epilogue defined in this header.
namespace {
// Tuple slot read with Get<Y_INDEX>(...) for the output (y) element offset.
constexpr uint32_t Y_INDEX = 0;
// Tuple slot read with Get<X2SCALE_IDX>(...) for the x2 scale element offset.
constexpr uint32_t X2SCALE_IDX = 2;
// Tuple slot read with Get<X1SCALE_IDX>(...) for the x1 scale element offset.
constexpr uint32_t X1SCALE_IDX = 3;
// Element count of each staging tensor (both input halves and the SwiGLU result).
// NOTE(review): presumably 64 rows x 256 columns per tile -- confirm against the tiling.
constexpr uint32_t MAX_SINGLE_MN_SIZE = 64 * 256;
// Element count reserved for each scale tensor (x2 first/second half and x1).
constexpr uint32_t MAX_SCALE_NUM = 256;
}
// Cast trait used by VFDoDequant to convert int32 accumulator values to float.
static constexpr AscendC::MicroAPI::CastTrait ctInt322Fp32S = {
AscendC::MicroAPI::RegLayout::UNKNOWN, AscendC::MicroAPI::SatMode::UNKNOWN,
AscendC::MicroAPI::MaskMergeMode::ZEROING, AscendC::RoundMode::CAST_RINT};
// Cast trait used by VFDoDequant to widen a 16-bit x2 scale register to float with RegLayout::ZERO.
static constexpr AscendC::MicroAPI::CastTrait ctHalf2Fp32ZeroS = {
AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::UNKNOWN, AscendC::MicroAPI::MaskMergeMode::ZEROING,
AscendC::RoundMode::UNKNOWN};
// Same as ctHalf2Fp32ZeroS but with RegLayout::ONE; the two results are interleaved afterwards.
static constexpr AscendC::MicroAPI::CastTrait ctHalf2Fp32OneS = {
AscendC::MicroAPI::RegLayout::ONE, AscendC::MicroAPI::SatMode::UNKNOWN, AscendC::MicroAPI::MaskMergeMode::ZEROING,
AscendC::RoundMode::UNKNOWN};
// Mode object passed as the template argument of MicroAPI::Div in VFDoSwiglu.
// NOTE(review): the meaning of the second field (true) is not visible here -- confirm in the MicroAPI docs.
static constexpr AscendC::MicroAPI::DivSpecificMode DIV_MODES = {AscendC::MicroAPI::MaskMergeMode::ZEROING, true};
// Cast trait used by VFDoSwiglu to narrow float results to half / bfloat16_t.
static constexpr AscendC::MicroAPI::CastTrait CAST_FP32_TO_B16 = {
AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::NO_SAT, AscendC::MicroAPI::MaskMergeMode::ZEROING,
AscendC::RoundMode::CAST_RINT};
// Template header shared by the class declaration and every out-of-class member definition below.
#define QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS \
template <typename L0TileShape_, typename DataTypeOut_, typename DataTypeIn_, typename DataTypeX2Scale_, \
typename DataTypeX1Scale_, bool IsTensorList_>
// Matching template argument list, used to spell BlockEpilogueDequantSwiglu<...> in member definitions.
#define QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS \
L0TileShape_, DataTypeOut_, DataTypeIn_, DataTypeX2Scale_, DataTypeX1Scale_, IsTensorList_
/**
 * Epilogue that takes two matmul result tiles staged in UB, dequantizes each of them with a
 * per-column x2 scale and a per-row x1 scale, combines them as
 * first / (1 + exp(-first)) * second, and writes the result to global memory.
 *
 * Template parameters:
 *   L0TileShape_      - not referenced by the members visible in this header.
 *   DataTypeOut_      - element type of the result written to global memory.
 *   DataTypeIn_       - element type of the staged matmul tiles.
 *   DataTypeX2Scale_  - element type of the x2 scale.
 *   DataTypeX1Scale_  - element type of the x1 scale.
 *   IsTensorList_     - not referenced by the members visible in this header.
 *
 * Expected call order, as far as this header shows: Init(), UpdateGlobalAddr(),
 * UpdateNextProblem(), then operator() once per block.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
class BlockEpilogueDequantSwiglu {
public:
    __aicore__ inline BlockEpilogueDequantSwiglu()
    {
    }

    /** Addresses and tile sizes supplied by the caller. Every field has a defined default. */
    struct Arguments {
        GM_ADDR workSpaceGMAddr{nullptr};  // base of the buffer the result is written to
        GM_ADDR x2ScaleGmAddr{nullptr};    // x2 scale, resolved through GetTensorAddr in UpdateGlobalAddr
        GM_ADDR x1ScaleGmAddr{nullptr};    // x1 scale base address
        uint32_t baseM{0};                 // not read by the members in this header
        uint32_t baseN{0};                 // not read by the members in this header
        Arguments() = default;
    };

    using Params = Arguments;
    using DataTypeOut = DataTypeOut_;
    using DataTypeIn = DataTypeIn_;
    using DataTypeX1Scale = DataTypeX1Scale_;
    using DataTypeX2Scale = DataTypeX2Scale_;
    using BlockShape = AscendC::Shape<int64_t, int64_t, int64_t, int64_t>;
    using BlockCoord = AscendC::Coord<int64_t, int64_t, int64_t, int64_t, int64_t>;
    using ProblemShape = AscendC::Shape<int64_t, int64_t, int64_t>;

    /** Stores the address of params and lays out the output/scale tensors in UB. */
    __aicore__ inline void Init(Params const &params);
    /** Returns a copy of the handle of the first staged input tile. */
    __aicore__ inline auto GetFirstL0c2UbTensor();
    /** Returns a copy of the handle of the second staged input tile. */
    __aicore__ inline auto GetSecondL0c2UbTensor();
    /** Runs the whole epilogue for one block. */
    __aicore__ inline void operator()(const BlockShape &blockShape, const BlockCoord &blockCoord);
    /** Re-bases the three global tensors by the offsets held in baseOffset. */
    __aicore__ inline void UpdateGlobalAddr(const BlockCoord &baseOffset);
    /** Caches the N extent of the next problem. */
    __aicore__ inline void UpdateNextProblem(const ProblemShape &problemShape);

private:
    __aicore__ inline void CopyOutputFromUb2Gm(uint64_t blockCount, uint64_t offset,
                                               AscendC::LocalTensor<DataTypeOut> &src);
    __aicore__ inline void VFDoSwiglu(__ubuf__ float *firstSrc, __ubuf__ float *secondSrc, uint16_t mSize,
                                      uint16_t nSize);
    __aicore__ inline void VFDoDequantWithX1X2Scale(__ubuf__ float *l0cOutUbFirstFloatAddr,
                                                    __ubuf__ float *l0cOutUbSecondFloatAddr, uint16_t mSize);
    __aicore__ inline void VFDoDequant(__ubuf__ float *dst, __ubuf__ DataTypeIn *l0cOut,
                                       __ubuf__ DataTypeX2Scale *scale2, __ubuf__ DataTypeX1Scale *x1Scale,
                                       uint16_t mSize, uint16_t nSize);
    __aicore__ inline void CopyX1ScaleFromGm2Ub(AscendC::LocalTensor<DataTypeX1Scale> &dst, uint64_t blockLen,
                                                uint64_t offset);
    __aicore__ inline void CopyX2ScaleFromGm2Ub(AscendC::LocalTensor<DataTypeX2Scale> &dst, uint64_t offset);
    __aicore__ inline void CopyDataFromGm2Ub(uint64_t halfSingleM, uint64_t singleMInVec);
    __aicore__ inline void VFDoPertokenDequantAndSwiglu(uint16_t mSize);

    // Global-memory views, re-based in UpdateGlobalAddr().
    AscendC::GlobalTensor<DataTypeOut> gluResGlobal_;
    AscendC::GlobalTensor<DataTypeX1Scale> x1ScaleGlobal_;
    AscendC::GlobalTensor<DataTypeX2Scale> x2ScaleGlobal_;

    // Staged input tiles at fixed UB byte offsets: first half at 0, second half right behind it.
    AscendC::LocalTensor<DataTypeIn> l0cOutUbFirst_{AscendC::TPosition::VECIN, 0, MAX_SINGLE_MN_SIZE};
    AscendC::LocalTensor<DataTypeIn> l0cOutUbSecond_{AscendC::TPosition::VECIN,
                                                     MAX_SINGLE_MN_SIZE * sizeof(DataTypeIn), MAX_SINGLE_MN_SIZE};
    // Float views used as dequantization destinations. They coincide with the tiles above only when
    // sizeof(DataTypeIn) == sizeof(float).
    AscendC::LocalTensor<float> l0cOutUbFirstFloat_{AscendC::TPosition::VECIN, 0, MAX_SINGLE_MN_SIZE};
    AscendC::LocalTensor<float> l0cOutUbSecondFloat_{AscendC::TPosition::VECIN, MAX_SINGLE_MN_SIZE * sizeof(float),
                                                     MAX_SINGLE_MN_SIZE};

    // Output and scale tensors; their UB offsets are assigned in Init().
    AscendC::LocalTensor<DataTypeOut> gluRes_;
    AscendC::LocalTensor<DataTypeX2Scale> x2ScaleUbFirst_;
    AscendC::LocalTensor<DataTypeX2Scale> x2ScaleUbSecond_;
    AscendC::LocalTensor<DataTypeX1Scale> x1ScaleUb_;

    // Non-owning; the object passed to Init() must outlive every later call on this instance.
    const Params *params_{nullptr};
    int64_t n_{0};  // N extent cached by UpdateNextProblem()
    uint32_t subBlockIdx_ = AscendC::GetSubBlockIdx();
    uint32_t singleM_{0};  // block M extent, set by operator()
    uint32_t singleN_{0};  // block N extent, set by operator()
    BlockCoord blockCoord_{0, 0, 0, 0, 0};
};
/**
 * Remembers the caller's parameters and places the result and scale tensors in UB.
 *
 * Does nothing when ASCEND_IS_AIC holds. Only the address of params is stored, so the
 * referenced object must stay alive for as long as this epilogue is used.
 *
 * UB layout in bytes (the two staged input tiles are declared in the class and start at 0):
 *   [0, afterIn)                    two input tiles of MAX_SINGLE_MN_SIZE DataTypeIn elements each
 *   [afterIn, afterGlu)             gluRes_          (MAX_SINGLE_MN_SIZE DataTypeOut elements)
 *   [afterGlu, afterGluAndHalfX2)   x2ScaleUbFirst_  (MAX_SCALE_NUM DataTypeX2Scale elements)
 *   [afterGluAndHalfX2, afterGluAndX2) x2ScaleUbSecond_
 *   [afterGluAndX2, ...)            x1ScaleUb_       (MAX_SCALE_NUM DataTypeX1Scale elements)
 *
 * @param params caller-owned arguments; must outlive this object.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::Init(Params const &params)
{
    // The float views of the input tiles are placed at 0 and MAX_SINGLE_MN_SIZE * sizeof(float),
    // while everything else is placed behind 2 * MAX_SINGLE_MN_SIZE * sizeof(DataTypeIn). Those
    // regions only line up when the input element is as wide as a float.
    static_assert(sizeof(DataTypeIn) == sizeof(float),
                  "BlockEpilogueDequantSwiglu UB layout requires a 4-byte DataTypeIn");
    if ASCEND_IS_AIC {
        return;
    }
    params_ = &params;

    constexpr uint32_t afterIn = 2 * MAX_SINGLE_MN_SIZE * sizeof(DataTypeIn);
    gluRes_ = AscendC::LocalTensor<DataTypeOut>(AscendC::TPosition::VECOUT, afterIn, MAX_SINGLE_MN_SIZE);

    constexpr uint32_t afterGlu = afterIn + MAX_SINGLE_MN_SIZE * sizeof(DataTypeOut);
    x2ScaleUbFirst_ = AscendC::LocalTensor<DataTypeX2Scale>(AscendC::TPosition::VECOUT, afterGlu, MAX_SCALE_NUM);

    constexpr uint32_t afterGluAndHalfX2 = afterGlu + MAX_SCALE_NUM * sizeof(DataTypeX2Scale);
    x2ScaleUbSecond_ =
        AscendC::LocalTensor<DataTypeX2Scale>(AscendC::TPosition::VECOUT, afterGluAndHalfX2, MAX_SCALE_NUM);

    constexpr uint32_t afterGluAndX2 = afterGluAndHalfX2 + MAX_SCALE_NUM * sizeof(DataTypeX2Scale);
    x1ScaleUb_ = AscendC::LocalTensor<DataTypeX1Scale>(AscendC::TPosition::VECOUT, afterGluAndX2, MAX_SCALE_NUM);
}
/**
 * Points the three global tensors at their base addresses plus the element offsets taken
 * from baseOffset. Only acts when ASCEND_IS_AIV holds.
 *
 * @param baseOffset tuple whose Y_INDEX / X1SCALE_IDX / X2SCALE_IDX slots hold element offsets.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UpdateGlobalAddr(const BlockCoord &baseOffset)
{
    if ASCEND_IS_AIV {
        // Result buffer: typed view of the workspace address, advanced in elements.
        __gm__ DataTypeOut *outBase = reinterpret_cast<__gm__ DataTypeOut *>(params_->workSpaceGMAddr);
        gluResGlobal_.SetGlobalBuffer(outBase + Get<Y_INDEX>(baseOffset));

        // x1 scale: plain address, advanced in elements.
        __gm__ DataTypeX1Scale *x1Base = reinterpret_cast<__gm__ DataTypeX1Scale *>(params_->x1ScaleGmAddr);
        x1ScaleGlobal_.SetGlobalBuffer(x1Base + Get<X1SCALE_IDX>(baseOffset));

        // x2 scale: entry 0 resolved through GetTensorAddr, then advanced in elements.
        auto x2Base = GetTensorAddr<DataTypeX2Scale>(0, params_->x2ScaleGmAddr);
        x2ScaleGlobal_.SetGlobalBuffer(x2Base + Get<X2SCALE_IDX>(baseOffset));
    }
}
/**
 * Caches the N extent of the upcoming problem; it is later used as the output row stride
 * and as the offset of the second x2 scale half.
 *
 * @param problemShape shape tuple whose MNK_N slot is read.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UpdateNextProblem(
    const ProblemShape &problemShape)
{
    this->n_ = Get<MNK_N>(problemShape);
}
/**
 * Writes blockCount result rows from src to the global result tensor.
 *
 * Each row transfers singleN_ elements; consecutive rows in global memory are separated by
 * the remaining (n_ - singleN_) elements. Both quantities are handed over in bytes.
 *
 * @param blockCount number of rows to copy.
 * @param offset     element offset into gluResGlobal_ of the first row.
 * @param src        tensor holding the rows to write.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyOutputFromUb2Gm(
    uint64_t blockCount, uint64_t offset, AscendC::LocalTensor<DataTypeOut> &src)
{
    const auto rowBytes = singleN_ * sizeof(DataTypeOut);
    const auto gmRowGapBytes = (n_ - singleN_) * sizeof(DataTypeOut);

    AscendC::DataCopyExtParams copyParams{1, 0, 0, 0, 0};
    copyParams.dstStride = gmRowGapBytes;
    copyParams.blockLen = rowBytes;
    copyParams.blockCount = blockCount;

    AscendC::DataCopyPad(gluResGlobal_[offset], src, copyParams);
}
// Computes, element by element, first / (1 + exp(-first)) * second over an mSize x nSize
// float tile and stores the result into gluRes_ as DataTypeOut.
//
// firstSrc / secondSrc: float tiles; rows are read with a stride of nSrcUbAligned elements.
// mSize: number of rows processed.
// nSize: number of valid elements per row.
//
// NOTE(review): the per-iteration element count, the mask and the source row stride are all
// derived from sizeof(DataTypeIn) although the data read here is float; this matches only
// when DataTypeIn is 4 bytes wide -- confirm that is guaranteed by the instantiations.
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoSwiglu(
__ubuf__ float *firstSrc, __ubuf__ float *secondSrc, uint16_t mSize, uint16_t nSize)
{
// Elements handled per inner-loop iteration.
constexpr uint16_t sizePerRepeat = AscendC::VECTOR_REG_WIDTH / sizeof(DataTypeIn);
// Inner-loop iterations needed to cover one row.
uint16_t OneRowRepeatTimes = CeilDiv(static_cast<uint64_t>(nSize), static_cast<uint64_t>(sizePerRepeat));
// Row strides (in elements) of the source tiles and of the destination tile.
uint32_t nSrcUbAligned = CeilAlign(nSize, AscendC::ONE_BLK_SIZE / sizeof(DataTypeIn));
uint32_t nDstUbAligned = CeilAlign(nSize, AscendC::ONE_BLK_SIZE / sizeof(DataTypeOut));
const float scalarOne = 1.0;
__ubuf__ DataTypeOut *gluResAddr = (__ubuf__ DataTypeOut *)gluRes_.GetPhyAddr();
__VEC_SCOPE__
{
for (uint16_t mIdx = 0; mIdx < mSize; mIdx++) {
// Remaining valid elements of the current row; handed to UpdateMask each iteration
// (presumably decremented by it so the last iteration gets a tail mask -- TODO confirm).
uint32_t elementNum = nSize;
for (uint16_t vfBlockIdx = 0; vfBlockIdx < OneRowRepeatTimes; vfBlockIdx++) {
AscendC::MicroAPI::MaskReg mask = AscendC::MicroAPI::UpdateMask<DataTypeIn>(elementNum);
AscendC::MicroAPI::RegTensor<float> l0cOutRegFirst, l0cOutRegSecond;
AscendC::MicroAPI::RegTensor<DataTypeOut> verg7;
// NOTE(review): verg1 and verg5 are declared but never used below.
AscendC::MicroAPI::RegTensor<float> verg1, verg2, verg3, verg4, verg5, verg6, swishOutput;
uint32_t l0cOutOffset = mIdx * nSrcUbAligned + vfBlockIdx * sizePerRepeat;
AscendC::MicroAPI::DataCopy(l0cOutRegFirst, firstSrc + l0cOutOffset);
// verg2 = -first
AscendC::MicroAPI::Muls(verg2, l0cOutRegFirst, -(scalarOne), mask);
// verg3 = exp(-first)
AscendC::MicroAPI::Exp(verg3, verg2, mask);
// verg4 = 1 + exp(-first)
AscendC::MicroAPI::Adds(verg4, verg3, scalarOne, mask);
// swishOutput = first / (1 + exp(-first))
AscendC::MicroAPI::Div<float, &DIV_MODES>(swishOutput, l0cOutRegFirst, verg4, mask);
AscendC::MicroAPI::DataCopy(l0cOutRegSecond, secondSrc + l0cOutOffset);
// verg6 = swishOutput * second
AscendC::MicroAPI::Mul(verg6, swishOutput, l0cOutRegSecond, mask);
uint32_t dstUbOffset = mIdx * nDstUbAligned + vfBlockIdx * sizePerRepeat;
// Convert to the output type and store; 16-bit outputs are cast first and stored with
// DIST_PACK_B32, float outputs are stored as they are with DIST_NORM_B32.
if constexpr (IsSameType<DataTypeOut, half>::value) {
AscendC::MicroAPI::Cast<half, float, CAST_FP32_TO_B16>(verg7, verg6, mask);
AscendC::MicroAPI::DataCopy<DataTypeOut, AscendC::MicroAPI::StoreDist::DIST_PACK_B32>(
gluResAddr + dstUbOffset, verg7, mask);
} else if constexpr (IsSameType<DataTypeOut, bfloat16_t>::value) {
AscendC::MicroAPI::Cast<bfloat16_t, float, CAST_FP32_TO_B16>(verg7, verg6, mask);
AscendC::MicroAPI::DataCopy<DataTypeOut, AscendC::MicroAPI::StoreDist::DIST_PACK_B32>(
gluResAddr + dstUbOffset, verg7, mask);
} else if constexpr (IsSameType<DataTypeOut, float>::value) {
verg7 = verg6;
AscendC::MicroAPI::DataCopy<DataTypeOut, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(
gluResAddr + dstUbOffset, verg7, mask);
}
}
}
}
}
/**
 * Dequantizes both staged input tiles: each is run through VFDoDequant with its own x2 scale
 * half and the shared x1 scale, over mSize rows of singleN_ elements.
 *
 * @param l0cOutUbFirstFloatAddr  float destination for the first tile.
 * @param l0cOutUbSecondFloatAddr float destination for the second tile.
 * @param mSize                   number of rows to process.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoDequantWithX1X2Scale(
    __ubuf__ float *l0cOutUbFirstFloatAddr, __ubuf__ float *l0cOutUbSecondFloatAddr, uint16_t mSize)
{
    __ubuf__ DataTypeIn *firstTile = reinterpret_cast<__ubuf__ DataTypeIn *>(l0cOutUbFirst_.GetPhyAddr());
    __ubuf__ DataTypeIn *secondTile = reinterpret_cast<__ubuf__ DataTypeIn *>(l0cOutUbSecond_.GetPhyAddr());
    __ubuf__ DataTypeX1Scale *x1Scale = reinterpret_cast<__ubuf__ DataTypeX1Scale *>(x1ScaleUb_.GetPhyAddr());

    // First tile with the first x2 scale half.
    __ubuf__ DataTypeX2Scale *x2ScaleFirst =
        reinterpret_cast<__ubuf__ DataTypeX2Scale *>(x2ScaleUbFirst_.GetPhyAddr());
    VFDoDequant(l0cOutUbFirstFloatAddr, firstTile, x2ScaleFirst, x1Scale, mSize, singleN_);

    // Second tile with the second x2 scale half.
    __ubuf__ DataTypeX2Scale *x2ScaleSecond =
        reinterpret_cast<__ubuf__ DataTypeX2Scale *>(x2ScaleUbSecond_.GetPhyAddr());
    VFDoDequant(l0cOutUbSecondFloatAddr, secondTile, x2ScaleSecond, x1Scale, mSize, singleN_);
}
// Dequantizes an mSize x nSize tile: dst[m][n] = float(l0cOut[m][n]) * scale2[n] * x1Scale[m].
//
// dst:     float destination; rows are written with the same stride as the source rows.
// l0cOut:  source tile of DataTypeIn elements.
// scale2:  x2 scale, indexed by column.
// x1Scale: x1 scale, indexed by row.
// mSize:   number of rows processed.
// nSize:   number of valid elements per row.
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoDequant(
__ubuf__ float *dst, __ubuf__ DataTypeIn *l0cOut, __ubuf__ DataTypeX2Scale *scale2,
__ubuf__ DataTypeX1Scale *x1Scale, uint16_t mSize, uint16_t nSize)
{
// Elements handled per inner-loop iteration.
uint32_t eleNumPerVf = AscendC::VECTOR_REG_WIDTH / sizeof(DataTypeIn);
// Row stride in elements (presumably nSize rounded up to a UB_ALIGN_SIZE-byte boundary -- TODO confirm Align).
uint32_t nSrcUbAligned = Align(nSize, static_cast<uint16_t>(UB_ALIGN_SIZE / sizeof(DataTypeIn)));
uint32_t nDstUbAligned = nSrcUbAligned;
// Inner-loop iterations needed to cover one row.
uint16_t nLoopCnt = (nSize + eleNumPerVf - 1) / eleNumPerVf;
__VEC_SCOPE__
{
// All-lanes mask for the x2 scale element type; used only by the second 16-bit scale cast below.
AscendC::MicroAPI::MaskReg maskN4B16 =
AscendC::MicroAPI::CreateMask<DataTypeX2Scale, AscendC::MicroAPI::MaskPattern::ALL>();
for (uint16_t mIdx = 0; mIdx < mSize; mIdx++) {
// Remaining valid elements of the current row; handed to UpdateMask each iteration
// (presumably decremented by it so the last iteration gets a tail mask -- TODO confirm).
uint32_t elementNum = nSize;
for (uint16_t vfBlockIdx = 0; vfBlockIdx < nLoopCnt; vfBlockIdx++) {
AscendC::MicroAPI::RegTensor<DataTypeIn> l0cOutReg;
AscendC::MicroAPI::RegTensor<DataTypeX2Scale> scaleReg;
AscendC::MicroAPI::RegTensor<DataTypeX1Scale> perTokenScaleReg;
AscendC::MicroAPI::RegTensor<float> l0cOutRegFloat;
AscendC::MicroAPI::RegTensor<float> castScaleReg, castScaleOneReg, mulScaleOutReg, mulPtScaleOutReg;
AscendC::MicroAPI::MaskReg maskN = AscendC::MicroAPI::UpdateMask<float>(elementNum);
uint32_t l0cOutOffset = mIdx * nSrcUbAligned + vfBlockIdx * eleNumPerVf;
AscendC::MicroAPI::DataCopy(l0cOutReg, l0cOut + l0cOutOffset);
// int32 input is converted to float; any other input type is assigned as-is
// (NOTE(review): that path presumably requires DataTypeIn to be float -- confirm).
if constexpr (IsSameType<DataTypeIn, int32_t>::value) {
AscendC::MicroAPI::Cast<float, DataTypeIn, ctInt322Fp32S>(l0cOutRegFloat, l0cOutReg, maskN);
} else {
l0cOutRegFloat = l0cOutReg;
}
// Load the x2 scales for this column chunk.
AscendC::MicroAPI::DataCopy(scaleReg, scale2 + vfBlockIdx * eleNumPerVf);
if constexpr (IsSameType<DataTypeX2Scale, bfloat16_t>::value ||
IsSameType<DataTypeX2Scale, half>::value) {
// 16-bit scales: widen with the ZERO and the ONE register layouts and interleave the two
// results (presumably this restores the original element order in castScaleReg -- TODO confirm).
// NOTE(review): the first cast uses maskN, the second the all-lanes maskN4B16.
AscendC::MicroAPI::Cast<float, DataTypeX2Scale, ctHalf2Fp32ZeroS>(castScaleReg, scaleReg, maskN);
AscendC::MicroAPI::Cast<float, DataTypeX2Scale, ctHalf2Fp32OneS>(castScaleOneReg, scaleReg,
maskN4B16);
AscendC::MicroAPI::Interleave(castScaleReg, castScaleOneReg, castScaleReg, castScaleOneReg);
} else if constexpr (IsSameType<DataTypeX2Scale, float>::value) {
castScaleReg = scaleReg;
}
// Apply the per-column x2 scale.
AscendC::MicroAPI::Mul(mulScaleOutReg, l0cOutRegFloat, castScaleReg, maskN);
// Load the x1 scale of this row with DIST_BRC_B32 (presumably a broadcast of one 32-bit
// value to all lanes, which would require a 32-bit DataTypeX1Scale -- TODO confirm).
AscendC::MicroAPI::DataCopy<DataTypeX1Scale, AscendC::MicroAPI::LoadDist::DIST_BRC_B32>(
perTokenScaleReg, x1Scale + mIdx);
// Apply the per-row x1 scale.
AscendC::MicroAPI::Mul(mulPtScaleOutReg, mulScaleOutReg, perTokenScaleReg, maskN);
uint32_t dstUbOffset = mIdx * nDstUbAligned + vfBlockIdx * eleNumPerVf;
AscendC::MicroAPI::DataCopy<float, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(
dst + dstUbOffset, mulPtScaleOutReg, maskN);
}
}
}
}
/**
 * Copies one contiguous run of x1 scale values from global memory into dst.
 *
 * @param dst      destination tensor.
 * @param blockLen length of the run as passed to DataCopyParams::blockLen (the caller supplies bytes).
 * @param offset   element offset added to the X1SCALE_IDX slot of the current block coordinate.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyX1ScaleFromGm2Ub(
    AscendC::LocalTensor<DataTypeX1Scale> &dst, uint64_t blockLen, uint64_t offset)
{
    // Single block; only its length differs from the zeroed defaults.
    DataCopyParams copyParams{1, 0, 0, 0};
    copyParams.blockLen = blockLen;
    // Default-constructed pad parameters.
    DataCopyPadParams defaultPad;

    const auto gmElemOffset = Get<X1SCALE_IDX>(blockCoord_) + offset;
    AscendC::DataCopyPad(dst, x1ScaleGlobal_[gmElemOffset], copyParams, defaultPad);
}
/**
 * Copies singleN_ x2 scale values from global memory into dst as one contiguous run.
 *
 * @param dst    destination tensor.
 * @param offset element offset added to the X2SCALE_IDX slot of the current block coordinate.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyX2ScaleFromGm2Ub(
    AscendC::LocalTensor<DataTypeX2Scale> &dst, uint64_t offset)
{
    // Single block whose length is the byte size of singleN_ scale elements.
    DataCopyParams copyParams{1, 0, 0, 0};
    copyParams.blockLen = singleN_ * sizeof(DataTypeX2Scale);
    // Default-constructed pad parameters.
    DataCopyPadParams defaultPad;

    const auto gmElemOffset = Get<X2SCALE_IDX>(blockCoord_) + offset;
    AscendC::DataCopyPad(dst, x2ScaleGlobal_[gmElemOffset], copyParams, defaultPad);
}
/**
 * Fetches every scale needed for the current block: both x2 scale halves (the second one
 * starts n_ elements after the first) and the x1 scales for the rows owned by this sub-block.
 *
 * @param halfSingleM  number of rows assigned to sub-block 0; used to locate this sub-block's rows.
 * @param singleMInVec number of rows this sub-block processes.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyDataFromGm2Ub(uint64_t halfSingleM,
                                                                                          uint64_t singleMInVec)
{
    // Per-column scales for the two tiles.
    CopyX2ScaleFromGm2Ub(x2ScaleUbFirst_, 0);
    CopyX2ScaleFromGm2Ub(x2ScaleUbSecond_, n_);

    // Per-row scales: start at this sub-block's first row, one value per processed row.
    const uint64_t firstRow = subBlockIdx_ * halfSingleM;
    const uint64_t x1ScaleBytes = singleMInVec * sizeof(DataTypeX1Scale);
    CopyX1ScaleFromGm2Ub(x1ScaleUb_, x1ScaleBytes, firstRow);
}
/**
 * Hands out the tensor handle of the first staged input tile (returned by value).
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline auto BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::GetFirstL0c2UbTensor()
{
    return this->l0cOutUbFirst_;
}
/**
 * Hands out the tensor handle of the second staged input tile (returned by value).
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline auto BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::GetSecondL0c2UbTensor()
{
    return this->l0cOutUbSecond_;
}
/**
 * Vector stage of the epilogue: dequantizes both tiles into their float views, then combines
 * the two float tiles with VFDoSwiglu over mSize rows of singleN_ elements.
 *
 * @param mSize number of rows to process.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoPertokenDequantAndSwiglu(uint16_t mSize)
{
    __ubuf__ float *firstFloatTile = reinterpret_cast<__ubuf__ float *>(l0cOutUbFirstFloat_.GetPhyAddr());
    __ubuf__ float *secondFloatTile = reinterpret_cast<__ubuf__ float *>(l0cOutUbSecondFloat_.GetPhyAddr());

    // Step 1: scale both tiles into float.
    VFDoDequantWithX1X2Scale(firstFloatTile, secondFloatTile, mSize);
    // Step 2: gate the second tile with the activated first tile.
    VFDoSwiglu(firstFloatTile, secondFloatTile, mSize, singleN_);
}
/**
 * Runs the epilogue for one block.
 *
 * The block's rows are divided between the sub-blocks: sub-block 1 takes what is left after
 * the first CeilDiv(singleM, GetTaskRation()) rows, every other sub-block takes that first
 * share. A sub-block with no rows returns right after recording the block shape and coordinate.
 *
 * @param blockShape shape tuple; the MNK_M and MNK_N slots give the block extents.
 * @param blockCoord coordinate tuple; its slots hold the element offsets for output and scales.
 */
QMM_BLOCK_EPILOGUE_DEQUANT_SWIGLU_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequantSwiglu<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::operator()(const BlockShape &blockShape,
                                                                                   const BlockCoord &blockCoord)
{
    singleM_ = Get<MNK_M>(blockShape);
    singleN_ = Get<MNK_N>(blockShape);
    blockCoord_ = blockCoord;

    // Row share of this sub-block.
    const auto halfSingleM =
        CeilDiv(static_cast<uint64_t>(singleM_), static_cast<uint64_t>(AscendC::GetTaskRation()));
    uint64_t singleMInVec = halfSingleM;
    if (subBlockIdx_ == 1) {
        singleMInVec = singleM_ - halfSingleM;
    }
    if (singleMInVec == 0) {
        return;
    }

    // Element offset of this sub-block's first output row.
    const uint64_t yOffset = Get<Y_INDEX>(blockCoord_) + subBlockIdx_ * halfSingleM * n_;

    // Load scales, then make the vector unit wait for the loads.
    CopyDataFromGm2Ub(halfSingleM, singleMInVec);
    AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0);
    AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0);

    // Compute, then make the store wait for the vector unit.
    VFDoPertokenDequantAndSwiglu(singleMInVec);
    AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(0);
    AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(0);

    // Store, then order later vector work and loads behind it.
    CopyOutputFromUb2Gm(singleMInVec, yOffset, gluRes_);
    AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(0);
    AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(0);
    AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(0);
    AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(0);
}
}
}
}
#endif
#endif