* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file block_epilogue_finalize_routing.h
* \brief
*/
#ifndef BLOCK_EPILOGUE_FINALIZE_ROUTING_H
#define BLOCK_EPILOGUE_FINALIZE_ROUTING_H
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510)
#if ASC_DEVKIT_MAJOR >= 9
#include "kernel_basic_intf.h"
#else
#include "kernel_operator.h"
#endif
#include "../utils/common_utils.h"
#include "../utils/device_utils.h"
#include "../utils/status_utils.h"
#include "../utils/tensor_utils.h"
#include "../tile/tile_copy_policy.h"
namespace Cgmct {
namespace Gemm {
namespace Block {
namespace {
constexpr int64_t OUT_ELE_NUM_ONE_BLK = 64;
constexpr uint32_t Y_IDX = 0;
constexpr uint32_t LOGIT_INDEX = 4;
constexpr uint32_t MAX_SINGLE_MN = 128 * 256;
constexpr uint32_t HALF_DB_MAX_SINGLE_MN = 32 * 256;
constexpr uint32_t BLOCK_BYTES = 256;
constexpr uint64_t MAX_OUTPUT_M_UB = 32;
constexpr uint64_t BLOCK_ELEMENTS_FP32 = 8;
}
using namespace AscendC;
#define GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS \
template <typename DataTypeOut_>
#define GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS \
DataTypeOut_
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
class BlockEpilogueFinalizeRouting {
public:
__aicore__ inline BlockEpilogueFinalizeRouting() {}
struct Arguments {
GM_ADDR yGMAddr{nullptr};
GM_ADDR x2ScaleGmAddr{nullptr};
GM_ADDR x1ScaleGmAddr{nullptr};
GM_ADDR biasGmAddr{nullptr};
GM_ADDR logitGmAddr{nullptr};
GM_ADDR rowIndexGmAddr{nullptr};
int32_t baseM = 256;
int32_t baseN = 256;
Arguments() = default;
};
using Params = Arguments;
using DataTypeOut = DataTypeOut_;
using BlockShape = AscendC::Shape<int64_t, int64_t, int64_t, int64_t>;
using BlockCoord = AscendC::Coord<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>;
using ProblemShape = AscendC::Shape<int64_t, int64_t, int64_t>;
public:
__aicore__ inline void Init(Params const ¶ms);
__aicore__ inline auto GetL0c2UbTensor();
__aicore__ inline void operator()(const BlockShape& blockShape, const BlockCoord& blockCoord);
__aicore__ inline void UpdateNextProblem(const ProblemShape& problemShape);
__aicore__ inline void UpdateGlobalAddr(const BlockCoord &baseOffset);
private:
__aicore__ inline void CopyInLogit(uint32_t curBaseM, uint64_t offsetM, LocalTensor<float> logitUb);
__aicore__ inline void VFDoLogitMuls(uint32_t offsetRe, uint32_t offsetLogit,
uint16_t repeatTimesLogit, uint16_t repeatTimesRe, __ubuf__ DataTypeOut* outUbAddr, LocalTensor<float> logitUb);
__aicore__ inline void VectorAtomicProcess(uint32_t curBaseN,
uint32_t curVecBaseM, uint64_t offsetM, uint64_t yOffset, LocalTensor<DataTypeOut> yLocal);
AscendC::GlobalTensor<float> logitGlobal_;
AscendC::GlobalTensor<int64_t> rowIndexGlobal_;
AscendC::GlobalTensor<DataTypeOut> yGlobal_;
AscendC::LocalTensor<DataTypeOut> l0cOutUb_{AscendC::TPosition::VECIN, 0, MAX_SINGLE_MN};
AscendC::LocalTensor<float> logitUbPing_;
AscendC::LocalTensor<float> logitUbPong_;
AscendC::LocalTensor<DataTypeOut> outUbPing_;
AscendC::LocalTensor<DataTypeOut> outUbPong_;
const Params *params_;
int64_t n_;
uint32_t subBlockIdx_ = AscendC::GetSubBlockIdx();
uint64_t singleM_;
uint64_t singleN_;
uint64_t alignN_;
uint64_t vlForFloatNumber_;
uint32_t repeatTimesLine_;
uint16_t yCrossPingPongID_ = 0;
uint16_t logitCrossPingPongID_ = 0;
};
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::Init(Params const ¶ms)
{
if ASCEND_IS_AIC
{
return;
}
params_ = ¶ms;
yGlobal_.SetGlobalBuffer((__gm__ DataTypeOut *)params_->yGMAddr);
uint32_t afterFirstIn = MAX_SINGLE_MN * sizeof(DataTypeOut);
logitUbPing_= AscendC::LocalTensor<float>(AscendC::TPosition::VECIN, afterFirstIn, BLOCK_BYTES);
uint32_t afterSecondIn = afterFirstIn + BLOCK_BYTES * sizeof(DataTypeOut);
logitUbPong_ = AscendC::LocalTensor<float>(AscendC::TPosition::VECIN, afterSecondIn, BLOCK_BYTES);
outUbPing_ = AscendC::LocalTensor<DataTypeOut>(AscendC::TPosition::VECOUT,
afterSecondIn + BLOCK_BYTES * sizeof(DataTypeOut), HALF_DB_MAX_SINGLE_MN);
outUbPong_ = AscendC::LocalTensor<DataTypeOut>(AscendC::TPosition::VECOUT,
afterSecondIn + (BLOCK_BYTES + HALF_DB_MAX_SINGLE_MN) * sizeof(DataTypeOut), HALF_DB_MAX_SINGLE_MN);
}
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline auto BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::GetL0c2UbTensor()
{
return l0cOutUb_;
}
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::UpdateNextProblem(
const ProblemShape& problemShape)
{
n_ = Get<MNK_N>(problemShape);
}
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::UpdateGlobalAddr(
const BlockCoord& baseOffset)
{
if ASCEND_IS_AIV {
logitGlobal_.SetGlobalBuffer((__gm__ float *)params_->logitGmAddr + Get<LOGIT_INDEX>(baseOffset));
rowIndexGlobal_.SetGlobalBuffer((__gm__ int64_t *)params_->rowIndexGmAddr + Get<LOGIT_INDEX>(baseOffset));
}
}
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::CopyInLogit(
uint32_t curBaseM, uint64_t offsetM, LocalTensor<float> logitUb)
{
DataCopyExtParams perTokenScaleParams{1, static_cast<uint32_t>(curBaseM * sizeof(float)), 0, 0, 0};
DataCopyPadExtParams<float> padParams;
DataCopyPad(logitUb, logitGlobal_[offsetM], perTokenScaleParams, padParams);
}
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::VectorAtomicProcess(
uint32_t curBaseN, uint32_t curVecBaseM, uint64_t offsetM, uint64_t yOffset, LocalTensor<DataTypeOut> yLocal)
{
SetAtomicAdd<float>();
DataCopyExtParams paramsOut{1, static_cast<uint32_t>(curBaseN * sizeof(DataTypeOut)), 0, 0, 0};
for (uint32_t i = 0; i < curVecBaseM; i++)
{
auto outRow = static_cast<uint64_t>(rowIndexGlobal_.GetValue(offsetM + i));
DataCopyPad(yGlobal_[outRow * n_ + yOffset], yLocal[i * alignN_], paramsOut);
}
SetAtomicNone();
}
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::VFDoLogitMuls(
uint32_t offsetRe, uint32_t offsetLogit, uint16_t repeatTimesLogit,uint16_t repeatTimesRe, __ubuf__ DataTypeOut* outUbAddr,
LocalTensor<float> logitUb)
{
__ubuf__ DataTypeOut* l0cOutUbAddr = (__ubuf__ DataTypeOut*)l0cOutUb_.GetPhyAddr();
__ubuf__ DataTypeOut* logitUbAddr = (__ubuf__ DataTypeOut*)logitUb.GetPhyAddr();
l0cOutUbAddr += offsetRe;
logitUbAddr += offsetLogit;
__VEC_SCOPE__
{
AscendC::MicroAPI::RegTensor<float> vregLogit;
AscendC::MicroAPI::RegTensor<DataTypeOut> vregRe, vDstReg;
AscendC::MicroAPI::MaskReg mask;
for (uint16_t i = 0; i < repeatTimesLogit; ++i) {
DataCopy<DataTypeOut, MicroAPI::LoadDist::DIST_BRC_B32>(vregLogit, logitUbAddr + i);
uint32_t elementNum = singleN_;
for (uint16_t j = 0; j < repeatTimesRe; ++j) {
mask = AscendC::MicroAPI::UpdateMask<DataTypeOut>(elementNum);
AscendC::MicroAPI::DataCopy(vregRe, l0cOutUbAddr + i * alignN_ + j * vlForFloatNumber_);
AscendC::MicroAPI::Mul(vDstReg, vregLogit, vregRe, mask);
AscendC::MicroAPI::DataCopy(outUbAddr + i * alignN_ + j * vlForFloatNumber_, vDstReg, mask);
}
}
}
}
GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueFinalizeRouting<GMM_BLOCK_EPILOGUE_FINALIZE_ROUTING_FUNC_LOCAL_PARAMS>::operator()(
const BlockShape& blockShape, const BlockCoord& blockCoord)
{
singleM_ = Get<MNK_M>(blockShape);
singleN_ = Get<MNK_N>(blockShape);
uint32_t halfSingleM = CeilDiv(singleM_,
static_cast<uint64_t>(AscendC::GetTaskRation()));
alignN_ = CeilDiv(singleN_, BLOCK_ELEMENTS_FP32) * BLOCK_ELEMENTS_FP32;
uint64_t singleMInVec = subBlockIdx_ == 1 ? singleM_ - halfSingleM : halfSingleM;
if (singleMInVec == 0) {
return;
}
uint32_t mOffset = subBlockIdx_ * halfSingleM;
vlForFloatNumber_ = BLOCK_BYTES / sizeof(DataTypeOut);
auto repeatTimesRe = CeilDiv(singleN_, vlForFloatNumber_);
uint64_t logitOffset = Get<LOGIT_INDEX>(blockCoord) + mOffset;
uint64_t yOffset = Get<Y_IDX>(blockCoord);
uint16_t remainRepeatTimesLogit = (singleMInVec % MAX_OUTPUT_M_UB != 0)
? singleMInVec % MAX_OUTPUT_M_UB : MAX_OUTPUT_M_UB;
auto logitUb = logitCrossPingPongID_ == 0 ? logitUbPing_ : logitUbPong_;
CopyInLogit(singleMInVec, logitOffset, logitUb);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(logitCrossPingPongID_ );
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(logitCrossPingPongID_ );
logitCrossPingPongID_ = (logitCrossPingPongID_ + 1) & 1;
uint32_t loopNumY = CeilDiv(singleMInVec, MAX_OUTPUT_M_UB);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(1);
for (int32_t i = 0; i < loopNumY; i++)
{
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(yCrossPingPongID_);
repeatTimesLine_ = (i == loopNumY - 1) ? remainRepeatTimesLogit : MAX_OUTPUT_M_UB;
__ubuf__ DataTypeOut* outUbAddr =
yCrossPingPongID_ == 0 ? (__ubuf__ DataTypeOut*)outUbPing_.GetPhyAddr() : (__ubuf__ DataTypeOut*)outUbPong_.GetPhyAddr();
VFDoLogitMuls(i * MAX_OUTPUT_M_UB * alignN_, i * MAX_OUTPUT_M_UB, repeatTimesLine_, repeatTimesRe, outUbAddr, logitUb);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(yCrossPingPongID_ );
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(yCrossPingPongID_);
auto yLocal = yCrossPingPongID_ == 0 ? outUbPing_ : outUbPong_;
VectorAtomicProcess(singleN_, repeatTimesLine_, logitOffset + i * MAX_OUTPUT_M_UB, yOffset, yLocal);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(yCrossPingPongID_);
yCrossPingPongID_= (yCrossPingPongID_+ 1) & 1;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(1);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(logitCrossPingPongID_);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(logitCrossPingPongID_);
return;
}
}
}
}
#endif
#endif