/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
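/*!
* \file block_epilogue_dequant.h
* \brief Block-level epilogue that dequantizes matmul results staged in UB (x1/x2 scales, optional bias) and writes them to GM.
*/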
#ifndef EPILOGUE_BLOCK_EPILOGUE_DEQUANT_H
#define EPILOGUE_BLOCK_EPILOGUE_DEQUANT_H
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510)
#if ASC_DEVKIT_MAJOR >= 9
#include "kernel_basic_intf.h"
#else
#include "kernel_operator.h"
#endif
#include "../utils/common_utils.h"
#include "../utils/device_utils.h"
#include "../utils/status_utils.h"
#include "../utils/tensor_utils.h"
#include "../tile/tile_copy_policy.h"
namespace Cgmct {
namespace Gemm {
namespace Block {
namespace Qmm {
// Quantization granularity of an input's scale. Each non-default mode occupies its own bit.
enum class QuantMode : uint32_t {
    DEFAULT = 0U,                // no mode selected
    PERTENSOR_MODE = 1U << 0,    // 0x01
    PERCHANNEL_MODE = 1U << 1,   // 0x02
    PERTOKEN_MODE = 1U << 2,     // 0x04
    MX_PERGROUP_MODE = 1U << 3,  // 0x08
    PERBLOCK_MODE = 1U << 4,     // 0x10
};
// Tiling parameters consumed by the dequant epilogue.
// All members carry default initializers so a default-constructed instance never holds indeterminate values.
struct DequantTiling {
    uint32_t baseM = 0;                          // base block rows; Init divides it across the vector sub-blocks
    uint32_t baseN = 0;                          // base block columns; sizes the per-column scale/bias UB buffers
    QuantMode x1QuantMode = QuantMode::DEFAULT;  // scale granularity of x1
    QuantMode x2QuantMode = QuantMode::DEFAULT;  // scale granularity of x2
    uint32_t biasDtype = DT_FLOAT;               // bias data type code (compared against DT_FLOAT / DT_BF16 / DT_FLOAT16)
    bool isBiasEpilogue = false;                 // true when the bias is added in this epilogue
};
}
// File-local constants shared by the dequant epilogue.
namespace {
// Divisor applied to the per-vector row count when Init sizes the ping/pong output staging buffers.
constexpr uint32_t FP32_OUTPUT_TIMES = 4;
// Tuple slots used with Get<>() on baseOffset_ / blockCoord_: output y, x2 scale, x1 scale and bias.
constexpr uint32_t Y_IDX = 0;
constexpr uint32_t X2SCALE_IDX = 1;
constexpr uint32_t X1SCALE_IDX = 2;
constexpr uint32_t BIAS_IDX = 3;
// Tuple slots used with Get<>() on the block shape (M, N); N_IDX is also applied to problemShape_.
constexpr uint32_t M_IDX = 0;
constexpr uint32_t N_IDX = 1;
}
// Cast trait used by VFDoDequant to convert an int32 L0C result register to float.
constexpr AscendC::MicroAPI::CastTrait ctInt322Fp32 = {
AscendC::MicroAPI::RegLayout::UNKNOWN, AscendC::MicroAPI::SatMode::UNKNOWN,
AscendC::MicroAPI::MaskMergeMode::ZEROING, AscendC::RoundMode::CAST_RINT};
// Cast trait used by VFDoDequant to narrow the float result to DataTypeOut when the output type is not float.
constexpr AscendC::MicroAPI::CastTrait ctFp322Half = {
AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::NO_SAT, AscendC::MicroAPI::MaskMergeMode::ZEROING,
AscendC::RoundMode::CAST_RINT};
// Cast traits used to widen 16-bit scale/bias registers to float. VFDoDequant casts the same source register
// once with the ZERO layout and once with the ONE layout, then interleaves the two float results.
constexpr AscendC::MicroAPI::CastTrait ctHalf2Fp32Zero = {
AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::UNKNOWN, AscendC::MicroAPI::MaskMergeMode::ZEROING,
AscendC::RoundMode::UNKNOWN};
constexpr AscendC::MicroAPI::CastTrait ctHalf2Fp32One = {
AscendC::MicroAPI::RegLayout::ONE, AscendC::MicroAPI::SatMode::UNKNOWN, AscendC::MicroAPI::MaskMergeMode::ZEROING,
AscendC::RoundMode::UNKNOWN};
// Template header shared by the BlockEpilogueDequant declaration and every out-of-class member definition.
#define QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS \
template <typename L0TileShape_, typename DataTypeOut_, typename DataTypeIn_, typename DataTypeX2Scale_, \
typename DataTypeX1Scale_, typename DataTypeBias_, bool IsTensorList_>
// Matching template argument list, used to name the class in out-of-class member definitions.
#define QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS \
L0TileShape_, DataTypeOut_, DataTypeIn_, DataTypeX2Scale_, DataTypeX1Scale_, DataTypeBias_, IsTensorList_
/**
 * Block-level dequant epilogue.
 *
 * Takes a matmul result block staged in UB (see GetL0c2UbTensor), multiplies it by the x2 scale and the
 * optional x1 scale, optionally adds a bias, converts it to DataTypeOut and copies it to the output in GM.
 *
 * Template parameters:
 *   L0TileShape_      tile shape from which l0M / l0N are extracted
 *   DataTypeOut_      element type of the output y
 *   DataTypeIn_       element type of the staged matmul result
 *   DataTypeX2Scale_  element type of the x2 scale
 *   DataTypeX1Scale_  element type of the x1 scale
 *   DataTypeBias_     element type of the bias (the runtime bias type is taken from the tiling)
 *   IsTensorList_     selects the tensor-list address resolution (GetTensorAddr) for y, x2 scale and bias
 */
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
class BlockEpilogueDequant {
public:
    // Sets the initial V_MTE2 / MTE3_V flags on the vector core; the destructor waits on them.
    __aicore__ inline BlockEpilogueDequant();
    __aicore__ inline ~BlockEpilogueDequant();

    // Host-facing argument bundle: GM addresses plus the dequant tiling.
    struct Arguments {
        GM_ADDR yGmAddr{nullptr};
        GM_ADDR x2ScaleGmAddr{nullptr};
        GM_ADDR x1ScaleGmAddr{nullptr};
        GM_ADDR biasGmAddr{nullptr};
        const Qmm::DequantTiling dequantTiling;
    };

    // Kernel-side parameter bundle; same layout as Arguments.
    struct Params {
        GM_ADDR yGmAddr{nullptr};
        GM_ADDR x2ScaleGmAddr{nullptr};
        GM_ADDR x1ScaleGmAddr{nullptr};
        GM_ADDR biasGmAddr{nullptr};
        const Qmm::DequantTiling dequantTiling;
    };

    using DataTypeOut = DataTypeOut_;
    using DataTypeIn = DataTypeIn_;
    using DataTypeX1Scale = DataTypeX1Scale_;
    using DataTypeX2Scale = DataTypeX2Scale_;
    using DataTypeBias = DataTypeBias_;

    static constexpr int64_t l0M = GetIntegralConstant<MNK_M, L0TileShape_>();
    static constexpr int64_t l0N = GetIntegralConstant<MNK_N, L0TileShape_>();

    using BlockShape = Shape<int64_t, int64_t, int64_t, int64_t>;
    // Element offsets indexed by Y_IDX / X2SCALE_IDX / X1SCALE_IDX / BIAS_IDX.
    using BaseOffset = Coord<int64_t, int64_t, int64_t, int64_t>;
    using BlockCoord = Coord<int64_t, int64_t, int64_t, int64_t>;
    using ProblemShape = Shape<int64_t, int64_t, int64_t>;

    // Stores the tiling pointer and problem shape and lays out the UB buffers.
    __aicore__ inline void Init(Params const& params, ProblemShape& problemShape);
    // Rebinds the global tensors / scalar scales from params using the current base offset and group index.
    __aicore__ inline void UpdateGlobalBuffer(Params const& params);
    // Records a new base offset and group index, then rebinds the global buffers.
    __aicore__ inline void UpdateGroupedParams(Params const& params, BaseOffset const& offset, uint32_t groupIdx);
    // Same as UpdateGroupedParams, with the batch index used as the group index.
    __aicore__ inline void UpdateBatchParams(Params const& params, BaseOffset const& offset, uint32_t batchIdx);
    // Dequantizes one block of shape blockShape located at blockCoord and writes it to GM.
    __aicore__ inline void Run(BlockShape& blockShape, BlockCoord& blockCoord);
    // Returns the UB tensor that holds the matmul result to be dequantized.
    __aicore__ inline auto GetL0c2UbTensor();
    // Sets the L2 cache hint of the output tensor.
    __aicore__ inline void SetOutL2CacheHint();
    // Forwards to Run.
    __aicore__ inline void operator()(BlockShape& blockShape, BlockCoord& blockCoord);

private:
    __aicore__ inline void UpdateTensorListGlobalBuffer(Params const& params);
    __aicore__ inline void UpdateTensorGlobalBuffer(Params const& params);
    __aicore__ inline void CopyDataFromGm2Ub();
    __aicore__ inline void CopyX1ScaleFromGm2Ub(AscendC::LocalTensor<DataTypeX1Scale>& dst, uint64_t blockLen,
                                                uint64_t offset);
    __aicore__ inline void CopyX2ScaleFromGm2Ub(AscendC::LocalTensor<DataTypeX2Scale>& dst);
    template <class BiasDtype>
    __aicore__ inline void CopyBiasFromGm2Ub(AscendC::LocalTensor<BiasDtype>& dst);
    __aicore__ inline void CopyDequantResFromUb2Gm(uint64_t blockCount, uint64_t offset,
                                                   AscendC::LocalTensor<DataTypeOut>& src);
    __aicore__ inline void UbSetFlag();
    __aicore__ inline void VFDoDequantWithX1Pertoken(__ubuf__ DataTypeOut* dequantOutInUbAddr,
                                                     __ubuf__ DataTypeIn* l0cOutUbAddr, uint64_t offsetPtScale,
                                                     uint16_t mSize);
    __aicore__ inline void VFDoDequantWithX1Pertensor(__ubuf__ DataTypeOut* dequantOutInUbAddr,
                                                      __ubuf__ DataTypeIn* l0cOutUbAddr, uint16_t mSize);
    __aicore__ inline void VFDoDequantWithoutX1Scale(__ubuf__ DataTypeOut* dequantOutInUbAddr,
                                                     __ubuf__ DataTypeIn* l0cOutUbAddr, uint16_t mSize);
    template <bool isPertensor, Qmm::QuantMode x1QuantMode, bool isBiasEpilogue, class BiasDtype>
    __aicore__ inline void VFDoDequant(__ubuf__ DataTypeOut* dst, __ubuf__ DataTypeIn* l0cOut,
                                       __ubuf__ DataTypeX2Scale* scale2, __ubuf__ DataTypeX1Scale* x1Scale,
                                       __ubuf__ BiasDtype* bias, uint16_t mSize, uint16_t nSize);

    // Global (GM) views.
    AscendC::GlobalTensor<DataTypeOut> yGlobal_;
    AscendC::GlobalTensor<float> biasGlobalFloat_;
    AscendC::GlobalTensor<bfloat16_t> biasGlobalB16_;
    AscendC::GlobalTensor<DataTypeX2Scale> x2ScaleGlobal_;
    AscendC::GlobalTensor<DataTypeX1Scale> x1ScaleGlobal_;

    // Local (UB) buffers laid out by Init.
    AscendC::LocalTensor<DataTypeIn> l0cOutUb_;
    AscendC::LocalTensor<DataTypeX2Scale> x2ScaleUb_;
    AscendC::LocalTensor<DataTypeX1Scale> x1ScaleUb_;
    AscendC::LocalTensor<bfloat16_t> biasUbB16_;
    AscendC::LocalTensor<float> biasUbFloat_;
    AscendC::LocalTensor<DataTypeOut> dequantOutInUBPing_;
    AscendC::LocalTensor<DataTypeOut> dequantOutInUBPong_;

    // Per-tensor scale values; overwritten by UpdateGlobalBuffer when the matching mode is PERTENSOR_MODE.
    float x2ScaleScalar_ = 1.0f;
    float x1ScaleScalar_ = 1.0f;
    // Points into the Params passed to Init; that object must outlive this epilogue.
    const Qmm::DequantTiling* dequantTiling_ = nullptr;
    ProblemShape problemShape_;
    uint32_t biasDtype_ = DT_FLOAT;
    uint32_t groupIdx_ = 0;
    uint32_t subBlockIdx_ = AscendC::GetSubBlockIdx();
    // Shape of the block currently processed by Run.
    uint32_t singleM_ = 0;
    uint32_t singleN_ = 0;
    // Selects the ping (0) or pong (1) output buffer; also used as the MTE3_V / V_MTE3 event id.
    uint32_t dequantOutInUBPingPongID_ = 0;
    bool isBiasEpilogue_ = false;

    // Event ids guarding the x2 scale, x1 scale and bias input buffers respectively.
    constexpr static uint16_t INPUT_BUFFER_FLAG_0 = 0;
    constexpr static uint16_t INPUT_BUFFER_FLAG_1 = 1;
    constexpr static uint16_t INPUT_BUFFER_FLAG_2 = 2;
    // Event ids guarding the ping and pong output buffers.
    constexpr static uint16_t OUTPUT_BUFFER_FLAG_0 = 0;
    constexpr static uint16_t OUTPUT_BUFFER_FLAG_1 = 1;

    BaseOffset baseOffset_{0, 0, 0, 0};
    BlockCoord blockCoord_{0, 0, 0, 0};
};
// Constructor: on the vector core, pre-sets one V_MTE2 flag per input buffer and one MTE3_V flag per
// output buffer, so that the first WaitFlag issued for each buffer in CopyDataFromGm2Ub / Run has a
// matching SetFlag.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline
BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::BlockEpilogueDequant()
{
if ASCEND_IS_AIV {
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_0);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_1);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_2);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(OUTPUT_BUFFER_FLAG_0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(OUTPUT_BUFFER_FLAG_1);
}
}
// Destructor: on the vector core, waits once on each of the five flags the constructor set
// (three V_MTE2 input flags, two MTE3_V output flags).
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline
BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::~BlockEpilogueDequant()
{
if ASCEND_IS_AIV {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_1);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_2);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(OUTPUT_BUFFER_FLAG_0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(OUTPUT_BUFFER_FLAG_1);
}
}
// Binds the global tensors / scalar scales for the tensor-list layout.
// y, the x2 scale and the bias are resolved through GetTensorAddr(0, ...); the x1 scale address is used directly.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UpdateTensorListGlobalBuffer(Params const& params)
{
    // x2 scale: a single scalar per group in per-tensor mode, otherwise a vector at the x2-scale base offset.
    auto x2ScaleBase = GetTensorAddr<DataTypeX2Scale>(0, params.x2ScaleGmAddr);
    if (dequantTiling_->x2QuantMode != Qmm::QuantMode::PERTENSOR_MODE) {
        x2ScaleGlobal_.SetGlobalBuffer(x2ScaleBase + Get<X2SCALE_IDX>(baseOffset_));
    } else {
        DataTypeX2Scale groupScale = *((__gm__ DataTypeX2Scale*)(x2ScaleBase + groupIdx_));
        if constexpr (AscendC::IsSameType<DataTypeX2Scale, bfloat16_t>::value) {
            x2ScaleScalar_ = AscendC::ToFloat(groupScale);
        } else {
            x2ScaleScalar_ = groupScale;
        }
    }

    // x1 scale: scalar per group, or a per-token vector; any other mode leaves the x1 state untouched.
    switch (dequantTiling_->x1QuantMode) {
        case (Qmm::QuantMode::PERTENSOR_MODE): {
            x1ScaleScalar_ = *((__gm__ DataTypeX1Scale*)params.x1ScaleGmAddr + groupIdx_);
            break;
        }
        case (Qmm::QuantMode::PERTOKEN_MODE): {
            x1ScaleGlobal_.SetGlobalBuffer((__gm__ DataTypeX1Scale*)params.x1ScaleGmAddr +
                                           Get<X1SCALE_IDX>(baseOffset_));
            break;
        }
        default: {
            break;
        }
    }

    // Bias: only bound when the bias is applied in this epilogue; float and 16-bit views are kept separately.
    if (isBiasEpilogue_) {
        if (biasDtype_ != DT_FLOAT) {
            biasGlobalB16_.SetGlobalBuffer(GetTensorAddr<bfloat16_t>(0, params.biasGmAddr) +
                                           Get<BIAS_IDX>(baseOffset_));
        } else {
            biasGlobalFloat_.SetGlobalBuffer(GetTensorAddr<float>(0, params.biasGmAddr) +
                                             Get<BIAS_IDX>(baseOffset_));
        }
    }

    yGlobal_.SetGlobalBuffer(GetTensorAddr<DataTypeOut>(0, params.yGmAddr) + Get<Y_IDX>(baseOffset_));
}
// Binds the global tensors / scalar scales for the plain-tensor layout (all GM addresses are used directly).
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UpdateTensorGlobalBuffer(Params const& params)
{
    // x2 scale: a single scalar per group in per-tensor mode, otherwise a vector at the x2-scale base offset.
    if (dequantTiling_->x2QuantMode != Qmm::QuantMode::PERTENSOR_MODE) {
        x2ScaleGlobal_.SetGlobalBuffer((__gm__ DataTypeX2Scale*)params.x2ScaleGmAddr +
                                       Get<X2SCALE_IDX>(baseOffset_));
    } else {
        DataTypeX2Scale groupScale = *((__gm__ DataTypeX2Scale*)params.x2ScaleGmAddr + groupIdx_);
        if constexpr (AscendC::IsSameType<DataTypeX2Scale, bfloat16_t>::value) {
            x2ScaleScalar_ = AscendC::ToFloat(groupScale);
        } else {
            x2ScaleScalar_ = groupScale;
        }
    }

    // x1 scale: scalar per group, or a per-token vector; any other mode leaves the x1 state untouched.
    switch (dequantTiling_->x1QuantMode) {
        case (Qmm::QuantMode::PERTENSOR_MODE): {
            x1ScaleScalar_ = *((__gm__ DataTypeX1Scale*)params.x1ScaleGmAddr + groupIdx_);
            break;
        }
        case (Qmm::QuantMode::PERTOKEN_MODE): {
            x1ScaleGlobal_.SetGlobalBuffer((__gm__ DataTypeX1Scale*)params.x1ScaleGmAddr +
                                           Get<X1SCALE_IDX>(baseOffset_));
            break;
        }
        default: {
            break;
        }
    }

    // Bias: only bound when the bias is applied in this epilogue.
    if (isBiasEpilogue_) {
        if (biasDtype_ != DT_FLOAT) {
            biasGlobalB16_.SetGlobalBuffer((__gm__ bfloat16_t*)params.biasGmAddr + Get<BIAS_IDX>(baseOffset_));
        } else {
            biasGlobalFloat_.SetGlobalBuffer((__gm__ float*)params.biasGmAddr + Get<BIAS_IDX>(baseOffset_));
        }
    }

    yGlobal_.SetGlobalBuffer((__gm__ DataTypeOut*)params.yGmAddr + Get<Y_IDX>(baseOffset_));
}
// Rebinds the global buffers, choosing the address-resolution variant at compile time from IsTensorList_.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UpdateGlobalBuffer(Params const& params)
{
    if constexpr (!IsTensorList_) {
        UpdateTensorGlobalBuffer(params);
    } else {
        UpdateTensorListGlobalBuffer(params);
    }
}
/**
 * Initializes the epilogue.
 *
 * Stores a pointer to params.dequantTiling (params must therefore outlive this object), records the
 * problem shape and carves the UB buffers out of one running byte offset, in this order:
 * matmul result, x2 scale (per-channel only), x1 scale (per-token only), bias (if applied here),
 * ping output buffer, pong output buffer.
 *
 * @param params        GM addresses and dequant tiling
 * @param problemShape  problem shape; its N_IDX extent is used as the output row pitch
 */
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS __aicore__ inline void
BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::Init(Params const& params,
                                                                         ProblemShape& problemShape)
{
    dequantTiling_ = &params.dequantTiling;
    // Rows handled by one vector sub-block.
    uint64_t mForSingleVec = CeilDiv(dequantTiling_->baseM, AscendC::GetTaskRation());
    uint64_t offset = 0;
    l0cOutUb_ = AscendC::LocalTensor<DataTypeIn>(AscendC::TPosition::VECIN, offset,
                                                 mForSingleVec * dequantTiling_->baseN);
    offset += mForSingleVec * dequantTiling_->baseN * sizeof(DataTypeIn);
    if ASCEND_IS_AIV {
        // The bias settings must be known before UpdateGlobalBuffer runs, otherwise it skips the bias binding.
        isBiasEpilogue_ = dequantTiling_->isBiasEpilogue && params.biasGmAddr != nullptr;
        biasDtype_ = dequantTiling_->biasDtype;
        UpdateGlobalBuffer(params);

        if (dequantTiling_->x2QuantMode == Qmm::QuantMode::PERCHANNEL_MODE) {
            x2ScaleUb_ =
                AscendC::LocalTensor<DataTypeX2Scale>(AscendC::TPosition::VECIN, offset, dequantTiling_->baseN);
            offset += dequantTiling_->baseN * sizeof(DataTypeX2Scale);
        }
        if (dequantTiling_->x1QuantMode == Qmm::QuantMode::PERTOKEN_MODE) {
            // The per-token scale buffer is rounded up to the UB alignment in bytes.
            const auto x1ScaleBytes =
                Align(mForSingleVec * sizeof(DataTypeX1Scale), static_cast<uint64_t>(UB_ALIGN_SIZE));
            x1ScaleUb_ = AscendC::LocalTensor<DataTypeX1Scale>(AscendC::TPosition::VECIN, offset,
                                                               x1ScaleBytes / sizeof(DataTypeX1Scale));
            offset += x1ScaleBytes;
        }
        if (isBiasEpilogue_) {
            if (biasDtype_ == DT_FLOAT) {
                biasUbFloat_ = AscendC::LocalTensor<float>(AscendC::TPosition::VECIN, offset, dequantTiling_->baseN);
                offset += dequantTiling_->baseN * sizeof(float);
            } else {
                // Every non-float bias type shares the 16-bit buffer.
                biasUbB16_ =
                    AscendC::LocalTensor<bfloat16_t>(AscendC::TPosition::VECIN, offset, dequantTiling_->baseN);
                offset += dequantTiling_->baseN * sizeof(bfloat16_t);
            }
        }
        // Each output staging buffer holds 1/FP32_OUTPUT_TIMES (rounded up) of the rows of one vector sub-block.
        const auto outElems = CeilDiv(mForSingleVec, FP32_OUTPUT_TIMES) * dequantTiling_->baseN;
        dequantOutInUBPing_ = AscendC::LocalTensor<DataTypeOut>(AscendC::TPosition::VECIN, offset, outElems);
        offset += outElems * sizeof(DataTypeOut);
        dequantOutInUBPong_ = AscendC::LocalTensor<DataTypeOut>(AscendC::TPosition::VECIN, offset, outElems);
    }
    problemShape_ = problemShape;
}
// Records the base offsets and group index for the next group, then rebinds the global buffers.
// The members are written first because UpdateGlobalBuffer reads baseOffset_ and groupIdx_.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UpdateGroupedParams(
Params const& params, BaseOffset const& offset, uint32_t groupIdx)
{
baseOffset_ = offset;
groupIdx_ = groupIdx;
UpdateGlobalBuffer(params);
}
// Batch variant of UpdateGroupedParams: the batch index is forwarded as the group index.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UpdateBatchParams(
    Params const& params, BaseOffset const& offset, uint32_t batchIdx)
{
    UpdateGroupedParams(params, offset, batchIdx);
}
// Loads the scale / bias inputs needed by the current block from GM into UB.
// Each load waits on the buffer's V_MTE2 flag, issues the copy, then sets and immediately waits on the
// matching MTE2_V flag.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyDataFromGm2Ub()
{
    // Row split between the vector sub-blocks: sub-block 1 takes whatever is left after the first share.
    const auto rowsFirstVec = CeilDiv(singleM_, AscendC::GetTaskRation());
    const auto rowsThisVec = subBlockIdx_ == 1 ? singleM_ - rowsFirstVec : rowsFirstVec;

    // Per-channel x2 scale.
    if (dequantTiling_->x2QuantMode == Qmm::QuantMode::PERCHANNEL_MODE) {
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_0);
        CopyX2ScaleFromGm2Ub(x2ScaleUb_);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(INPUT_BUFFER_FLAG_0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(INPUT_BUFFER_FLAG_0);
    }

    // Per-token x1 scale: only the rows owned by this sub-block are loaded.
    uint64_t rowStart = subBlockIdx_ * rowsFirstVec;
    if (dequantTiling_->x1QuantMode == Qmm::QuantMode::PERTOKEN_MODE) {
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_1);
        CopyX1ScaleFromGm2Ub(x1ScaleUb_, rowsThisVec * sizeof(DataTypeX1Scale),
                             Get<X1SCALE_IDX>(blockCoord_) + rowStart);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(INPUT_BUFFER_FLAG_1);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(INPUT_BUFFER_FLAG_1);
    }

    // Bias: float goes to the float buffer, every other type to the 16-bit buffer.
    if (isBiasEpilogue_) {
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_2);
        if (biasDtype_ == DT_FLOAT) {
            CopyBiasFromGm2Ub<float>(biasUbFloat_);
        } else {
            CopyBiasFromGm2Ub<bfloat16_t>(biasUbB16_);
        }
        AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(INPUT_BUFFER_FLAG_2);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(INPUT_BUFFER_FLAG_2);
    }
}
// Copies one contiguous run of the per-token x1 scale from GM into dst.
//   blockLen  length of the run in bytes
//   offset    element offset into x1ScaleGlobal_
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyX1ScaleFromGm2Ub(
    AscendC::LocalTensor<DataTypeX1Scale>& dst, uint64_t blockLen, uint64_t offset)
{
    DataCopyPadParams noPadding;
    DataCopyParams copyParams{1, 0, 0, 0};
    copyParams.blockLen = blockLen;
    AscendC::DataCopyPad(dst, x1ScaleGlobal_[offset], copyParams, noPadding);
}
// Copies singleN_ elements of the x2 scale, starting at the block's x2-scale coordinate, from GM into dst.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyX2ScaleFromGm2Ub(
    AscendC::LocalTensor<DataTypeX2Scale>& dst)
{
    DataCopyPadParams noPadding;
    DataCopyParams copyParams{1, 0, 0, 0};
    copyParams.blockLen = singleN_ * sizeof(DataTypeX2Scale);
    AscendC::DataCopyPad(dst, x2ScaleGlobal_[Get<X2SCALE_IDX>(blockCoord_)], copyParams, noPadding);
}
// Copies singleN_ bias elements, starting at the block's bias coordinate, from GM into dst.
// A float BiasDtype reads from the float global view; any other type reads from the 16-bit view.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
template <class BiasDtype>
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyBiasFromGm2Ub(
    AscendC::LocalTensor<BiasDtype>& dst)
{
    DataCopyPadParams noPadding;
    DataCopyParams copyParams{1, 0, 0, 0};
    if constexpr (!IsSameType<BiasDtype, float>::value) {
        copyParams.blockLen = singleN_ * sizeof(bfloat16_t);
        AscendC::DataCopyPad(dst, biasGlobalB16_[Get<BIAS_IDX>(blockCoord_)], copyParams, noPadding);
    } else {
        copyParams.blockLen = singleN_ * sizeof(float);
        AscendC::DataCopyPad(dst, biasGlobalFloat_[Get<BIAS_IDX>(blockCoord_)], copyParams, noPadding);
    }
}
// Writes blockCount rows of singleN_ dequantized elements from src (UB) to the output in GM.
//   blockCount  number of rows to write
//   offset      element offset added to the block's y coordinate
// The destination stride skips the remainder of each output row (problem N minus singleN_).
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::CopyDequantResFromUb2Gm(
    uint64_t blockCount, uint64_t offset, AscendC::LocalTensor<DataTypeOut>& src)
{
    DataCopyExtParams storeParams{1, 0, 0, 0, 0};
    storeParams.blockCount = blockCount;
    storeParams.blockLen = singleN_ * sizeof(DataTypeOut);
    storeParams.dstStride = (Get<N_IDX>(problemShape_) - singleN_) * sizeof(DataTypeOut);
    AscendC::DataCopyPad(yGlobal_[Get<Y_IDX>(blockCoord_) + offset], src, storeParams);
}
// Re-sets the V_MTE2 flag of every input buffer that CopyDataFromGm2Ub waited on for the current
// configuration, so the next load of that buffer can proceed.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::UbSetFlag()
{
    const bool x2ScaleLoaded = dequantTiling_->x2QuantMode == Qmm::QuantMode::PERCHANNEL_MODE;
    const bool x1ScaleLoaded = dequantTiling_->x1QuantMode == Qmm::QuantMode::PERTOKEN_MODE;

    if (x2ScaleLoaded) {
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_0);
    }
    if (x1ScaleLoaded) {
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_1);
    }
    // The bias buffer uses the same flag regardless of the bias data type.
    if (isBiasEpilogue_) {
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(INPUT_BUFFER_FLAG_2);
    }
}
// Dispatches VFDoDequant for the per-token x1 scale case, selecting the template instantiation from
// the x2 quant mode (per-tensor or not), the bias presence and the runtime bias type.
//   offsetPtScale  element offset of the first row's scale inside x1ScaleUb_
//   mSize          number of rows to process
// An unrecognised bias type with bias enabled results in no call.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoDequantWithX1Pertoken(
    __ubuf__ DataTypeOut* dequantOutInUbAddr, __ubuf__ DataTypeIn* l0cOutUbAddr, uint64_t offsetPtScale, uint16_t mSize)
{
    constexpr Qmm::QuantMode kX1Mode = Qmm::QuantMode::PERTOKEN_MODE;
    __ubuf__ DataTypeX1Scale* tokenScaleAddr = (__ubuf__ DataTypeX1Scale*)x1ScaleUb_.GetPhyAddr() + offsetPtScale;
    const bool x2PerTensor = dequantTiling_->x2QuantMode == Qmm::QuantMode::PERTENSOR_MODE;

    // No bias in the epilogue.
    if (!isBiasEpilogue_) {
        if (x2PerTensor) {
            VFDoDequant<true, kX1Mode, false, float>(dequantOutInUbAddr, l0cOutUbAddr, nullptr, tokenScaleAddr,
                                                     nullptr, mSize, singleN_);
        } else {
            VFDoDequant<false, kX1Mode, false, float>(dequantOutInUbAddr, l0cOutUbAddr,
                                                      (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(),
                                                      tokenScaleAddr, nullptr, mSize, singleN_);
        }
        return;
    }

    // Bias present: pick the bias element type.
    if (biasDtype_ == DT_FLOAT) {
        __ubuf__ float* biasAddr = (__ubuf__ float*)biasUbFloat_.GetPhyAddr();
        if (x2PerTensor) {
            VFDoDequant<true, kX1Mode, true, float>(dequantOutInUbAddr, l0cOutUbAddr, nullptr, tokenScaleAddr,
                                                    biasAddr, mSize, singleN_);
        } else {
            VFDoDequant<false, kX1Mode, true, float>(dequantOutInUbAddr, l0cOutUbAddr,
                                                     (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(),
                                                     tokenScaleAddr, biasAddr, mSize, singleN_);
        }
    } else if (biasDtype_ == DT_BF16) {
        __ubuf__ bfloat16_t* biasAddr = (__ubuf__ bfloat16_t*)biasUbB16_.GetPhyAddr();
        if (x2PerTensor) {
            VFDoDequant<true, kX1Mode, true, bfloat16_t>(dequantOutInUbAddr, l0cOutUbAddr, nullptr, tokenScaleAddr,
                                                         biasAddr, mSize, singleN_);
        } else {
            VFDoDequant<false, kX1Mode, true, bfloat16_t>(dequantOutInUbAddr, l0cOutUbAddr,
                                                          (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(),
                                                          tokenScaleAddr, biasAddr, mSize, singleN_);
        }
    } else if (biasDtype_ == DT_FLOAT16) {
        // fp16 bias shares the 16-bit UB buffer and is reinterpreted as half.
        __ubuf__ half* biasAddr = (__ubuf__ half*)biasUbB16_.GetPhyAddr();
        if (x2PerTensor) {
            VFDoDequant<true, kX1Mode, true, half>(dequantOutInUbAddr, l0cOutUbAddr, nullptr, tokenScaleAddr,
                                                   biasAddr, mSize, singleN_);
        } else {
            VFDoDequant<false, kX1Mode, true, half>(dequantOutInUbAddr, l0cOutUbAddr,
                                                    (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(),
                                                    tokenScaleAddr, biasAddr, mSize, singleN_);
        }
    }
}
// Runs VFDoDequant for the per-tensor x1 scale case.
// This path always uses the vector x2 scale from x2ScaleUb_ and never adds a bias.
// NOTE(review): the x2 quant mode and isBiasEpilogue_ are not consulted here - confirm that is intended.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoDequantWithX1Pertensor(
    __ubuf__ DataTypeOut* dequantOutInUbAddr, __ubuf__ DataTypeIn* l0cOutUbAddr, uint16_t mSize)
{
    __ubuf__ DataTypeX2Scale* x2ScaleAddr = (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr();
    VFDoDequant<false, Qmm::QuantMode::PERTENSOR_MODE, false, float>(dequantOutInUbAddr, l0cOutUbAddr, x2ScaleAddr,
                                                                     nullptr, nullptr, mSize, singleN_);
}
// Dispatches VFDoDequant when no x1 scale is applied, selecting the template instantiation from the
// bias presence, the runtime bias type and (when a bias is present) the x2 quant mode.
//   mSize  number of rows to process
// The DT_FLOAT16 branch mirrors VFDoDequantWithX1Pertoken: an fp16 bias is loaded into the 16-bit
// bias buffer like a bf16 one, so it must be dispatched here as well instead of being silently skipped.
// NOTE(review): the no-bias branch always uses the vector x2 scale from x2ScaleUb_, whatever the x2
// quant mode is - confirm that per-tensor x2 without bias never reaches this path.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoDequantWithoutX1Scale(
    __ubuf__ DataTypeOut* dequantOutInUbAddr, __ubuf__ DataTypeIn* l0cOutUbAddr, uint16_t mSize)
{
    if (!isBiasEpilogue_) {
        VFDoDequant<false, Qmm::QuantMode::DEFAULT, false, float>(dequantOutInUbAddr, l0cOutUbAddr,
                                                                  (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(),
                                                                  nullptr, nullptr, mSize, singleN_);
    } else {
        if (biasDtype_ == DT_FLOAT) {
            if (dequantTiling_->x2QuantMode == Qmm::QuantMode::PERTENSOR_MODE) {
                VFDoDequant<true, Qmm::QuantMode::DEFAULT, true, float>(
                    dequantOutInUbAddr, l0cOutUbAddr, nullptr, nullptr, (__ubuf__ float*)biasUbFloat_.GetPhyAddr(),
                    mSize, singleN_);
            } else {
                VFDoDequant<false, Qmm::QuantMode::DEFAULT, true, float>(
                    dequantOutInUbAddr, l0cOutUbAddr, (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(), nullptr,
                    (__ubuf__ float*)biasUbFloat_.GetPhyAddr(), mSize, singleN_);
            }
        } else if (biasDtype_ == DT_BF16) {
            if (dequantTiling_->x2QuantMode == Qmm::QuantMode::PERTENSOR_MODE) {
                VFDoDequant<true, Qmm::QuantMode::DEFAULT, true, bfloat16_t>(
                    dequantOutInUbAddr, l0cOutUbAddr, nullptr, nullptr, (__ubuf__ bfloat16_t*)biasUbB16_.GetPhyAddr(),
                    mSize, singleN_);
            } else {
                VFDoDequant<false, Qmm::QuantMode::DEFAULT, true, bfloat16_t>(
                    dequantOutInUbAddr, l0cOutUbAddr, (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(), nullptr,
                    (__ubuf__ bfloat16_t*)biasUbB16_.GetPhyAddr(), mSize, singleN_);
            }
        } else if (biasDtype_ == DT_FLOAT16) {
            // fp16 bias shares the 16-bit UB buffer and is reinterpreted as half.
            if (dequantTiling_->x2QuantMode == Qmm::QuantMode::PERTENSOR_MODE) {
                VFDoDequant<true, Qmm::QuantMode::DEFAULT, true, half>(
                    dequantOutInUbAddr, l0cOutUbAddr, nullptr, nullptr, (__ubuf__ half*)biasUbB16_.GetPhyAddr(),
                    mSize, singleN_);
            } else {
                VFDoDequant<false, Qmm::QuantMode::DEFAULT, true, half>(
                    dequantOutInUbAddr, l0cOutUbAddr, (__ubuf__ DataTypeX2Scale*)x2ScaleUb_.GetPhyAddr(), nullptr,
                    (__ubuf__ half*)biasUbB16_.GetPhyAddr(), mSize, singleN_);
            }
        }
    }
}
// Vector-function kernel that dequantizes mSize rows of nSize elements held in UB.
// Per element it computes: out = cast(l0cOut) * x2Scale [* x1Scale] [+ bias], then converts to DataTypeOut.
// Template parameters:
//   isPertensor     true: multiply by the scalar x2ScaleScalar_; false: multiply by the vector read from scale2
//   x1QuantMode     PERTENSOR_MODE: multiply by x1ScaleScalar_; PERTOKEN_MODE: multiply by x1Scale[row];
//                   anything else: no x1 scaling
//   isBiasEpilogue  true: add the bias vector read from bias
//   BiasDtype       element type of bias (bfloat16_t / half are widened to float first)
// Parameters:
//   dst      output rows in UB, row pitch = nSize aligned to UB_ALIGN_SIZE bytes of DataTypeOut
//   l0cOut   input rows in UB, row pitch = nSize aligned to UB_ALIGN_SIZE bytes of DataTypeIn
//   scale2   x2 scale vector (unused when isPertensor)
//   x1Scale  per-row x1 scale (used only in PERTOKEN_MODE)
//   bias     bias vector (used only when isBiasEpilogue)
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
template <bool isPertensor, Qmm::QuantMode x1QuantMode, bool isBiasEpilogue, class BiasDtype>
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::VFDoDequant(
__ubuf__ DataTypeOut* dst, __ubuf__ DataTypeIn* l0cOut, __ubuf__ DataTypeX2Scale* scale2,
__ubuf__ DataTypeX1Scale* x1Scale, __ubuf__ BiasDtype* bias, uint16_t mSize, uint16_t nSize)
{
// Number of DataTypeIn elements processed per vector-register iteration.
uint32_t eleNumPerVf = AscendC::VECTOR_REG_WIDTH / sizeof(DataTypeIn);
// Row pitches (in elements) of the source and destination buffers.
uint32_t nSrcUbAligned = Align(nSize, static_cast<uint16_t>(UB_ALIGN_SIZE / sizeof(DataTypeIn)));
uint32_t nDstUbAligned = Align(nSize, static_cast<uint16_t>(UB_ALIGN_SIZE / sizeof(DataTypeOut)));
// Register-width chunks needed to cover one row (ceil division).
uint16_t nLoopCnt = (nSize + eleNumPerVf - 1) / eleNumPerVf;
__VEC_SCOPE__
{
// All-lanes mask for 16-bit elements; used for the second (ONE-layout) half of the 16-bit -> float widening.
AscendC::MicroAPI::MaskReg maskN4B16 =
AscendC::MicroAPI::CreateMask<bfloat16_t, AscendC::MicroAPI::MaskPattern::ALL>();
for (uint16_t mIdx = 0; mIdx < mSize; mIdx++) {
// Remaining element count of this row, handed to UpdateMask for every chunk.
// NOTE(review): presumably UpdateMask decrements it by reference to produce the tail mask - confirm.
uint32_t elementNum = nSize;
for (uint16_t vfBlockIdx = 0; vfBlockIdx < nLoopCnt; vfBlockIdx++) {
AscendC::MicroAPI::RegTensor<DataTypeIn> l0cOutReg;
AscendC::MicroAPI::RegTensor<DataTypeX2Scale> scaleReg;
AscendC::MicroAPI::RegTensor<DataTypeX1Scale> perTokenScaleReg;
AscendC::MicroAPI::RegTensor<BiasDtype> biasReg;
AscendC::MicroAPI::RegTensor<float> castSrcOutReg, castScaleReg, castScaleOneReg, mulScaleOutReg,
mulPtScaleOutReg, castBiasReg, castBiasOneReg, addBiasOutReg;
AscendC::MicroAPI::RegTensor<DataTypeOut> castResultOutReg;
AscendC::MicroAPI::MaskReg maskN = AscendC::MicroAPI::UpdateMask<DataTypeIn>(elementNum);
// Step 1: load one chunk of the matmul result and bring it to float.
uint32_t l0cOutOffset = mIdx * nSrcUbAligned + vfBlockIdx * eleNumPerVf;
AscendC::MicroAPI::DataCopy(l0cOutReg, l0cOut + l0cOutOffset);
if constexpr (IsSameType<DataTypeIn, int32_t>::value) {
AscendC::MicroAPI::Cast<float, DataTypeIn, ctInt322Fp32>(castSrcOutReg, l0cOutReg, maskN);
} else {
// Non-int32 input is taken over without conversion (assumes DataTypeIn is float here - TODO confirm).
castSrcOutReg = l0cOutReg;
}
// Step 2: apply the x2 scale (scalar or per-column vector).
if constexpr (isPertensor) {
AscendC::MicroAPI::Muls(mulScaleOutReg, castSrcOutReg, x2ScaleScalar_, maskN);
} else {
AscendC::MicroAPI::DataCopy(scaleReg, scale2 + vfBlockIdx * eleNumPerVf);
if constexpr (!IsSameType<DataTypeX2Scale, float>::value) {
// Widen the non-float scale: cast with the ZERO and ONE layouts, then interleave the two results.
AscendC::MicroAPI::Cast<float, DataTypeX2Scale, ctHalf2Fp32Zero>(castScaleReg, scaleReg, maskN);
AscendC::MicroAPI::Cast<float, DataTypeX2Scale, ctHalf2Fp32One>(castScaleOneReg, scaleReg,
maskN4B16);
AscendC::MicroAPI::Interleave(castScaleReg, castScaleOneReg, castScaleReg, castScaleOneReg);
} else {
castScaleReg = scaleReg;
}
AscendC::MicroAPI::Mul(mulScaleOutReg, castSrcOutReg, castScaleReg, maskN);
}
// Step 3: apply the x1 scale (scalar, per-row broadcast, or none).
if constexpr (x1QuantMode == Qmm::QuantMode::PERTENSOR_MODE) {
AscendC::MicroAPI::Muls(mulPtScaleOutReg, mulScaleOutReg, x1ScaleScalar_, maskN);
} else if constexpr (x1QuantMode == Qmm::QuantMode::PERTOKEN_MODE) {
// Loads the scale of row mIdx with the DIST_BRC_B32 load distribution.
AscendC::MicroAPI::DataCopy<DataTypeX1Scale, AscendC::MicroAPI::LoadDist::DIST_BRC_B32>(
perTokenScaleReg, x1Scale + mIdx);
AscendC::MicroAPI::Mul(mulPtScaleOutReg, mulScaleOutReg, perTokenScaleReg, maskN);
} else {
mulPtScaleOutReg = mulScaleOutReg;
}
// Step 4: optionally add the per-column bias.
if constexpr (isBiasEpilogue) {
AscendC::MicroAPI::DataCopy(biasReg, bias + vfBlockIdx * eleNumPerVf);
if constexpr (IsSameType<BiasDtype, bfloat16_t>::value || IsSameType<BiasDtype, half>::value) {
// Same two-cast + interleave widening as for the 16-bit x2 scale.
AscendC::MicroAPI::Cast<float, BiasDtype, ctHalf2Fp32Zero>(castBiasReg, biasReg, maskN);
AscendC::MicroAPI::Cast<float, BiasDtype, ctHalf2Fp32One>(castBiasOneReg, biasReg, maskN4B16);
AscendC::MicroAPI::Interleave(castBiasReg, castBiasOneReg, castBiasReg, castBiasOneReg);
} else {
castBiasReg = biasReg;
}
AscendC::MicroAPI::Add(addBiasOutReg, mulPtScaleOutReg, castBiasReg, maskN);
} else {
addBiasOutReg = mulPtScaleOutReg;
}
// Step 5: convert to the output type when it is not float.
if constexpr (!IsSameType<DataTypeOut, float>::value) {
AscendC::MicroAPI::Cast<DataTypeOut, float, ctFp322Half>(castResultOutReg, addBiasOutReg, maskN);
} else {
castResultOutReg = addBiasOutReg;
}
// Step 6: store the chunk; float output uses DIST_NORM_B32, any other output type uses DIST_PACK_B32.
uint32_t dstUbOffset = mIdx * nDstUbAligned + vfBlockIdx * eleNumPerVf;
if constexpr (IsSameType<DataTypeOut, float>::value) {
AscendC::MicroAPI::DataCopy<DataTypeOut, AscendC::MicroAPI::StoreDist::DIST_NORM_B32>(
dst + dstUbOffset, castResultOutReg, maskN);
} else {
AscendC::MicroAPI::DataCopy<DataTypeOut, AscendC::MicroAPI::StoreDist::DIST_PACK_B32>(
dst + dstUbOffset, castResultOutReg, maskN);
}
}
}
}
}
/**
 * Dequantizes one block and writes it to GM.
 *
 * The rows of the block are split between the vector sub-blocks; the rows of this sub-block are then
 * processed in up to FP32_OUTPUT_TIMES slices, alternating between the ping and pong output buffers.
 * The slice count uses the same constant that sizes those buffers in Init so the two cannot drift apart.
 *
 * @param blockShape  shape of the block; M_IDX / N_IDX give its rows and columns
 * @param blockCoord  element coordinates of the block for y, x2 scale, x1 scale and bias
 */
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::Run(BlockShape& blockShape,
                                                                                               BlockCoord& blockCoord)
{
    singleM_ = Get<M_IDX>(blockShape);
    singleN_ = Get<N_IDX>(blockShape);
    blockCoord_ = blockCoord;
    // Row split between the vector sub-blocks: sub-block 1 takes whatever is left after the first share.
    auto halfSingleM = CeilDiv(singleM_, AscendC::GetTaskRation());
    uint64_t singleMInVec = subBlockIdx_ == 1 ? singleM_ - halfSingleM : halfSingleM;
    if (singleMInVec == 0) {
        // Nothing to do for this sub-block; no flag is consumed, so none has to be re-set.
        return;
    }
    uint64_t mOffset = subBlockIdx_ * halfSingleM;
    CopyDataFromGm2Ub();
    // At most FP32_OUTPUT_TIMES slices, and never more slices than rows.
    const uint64_t maxSplitNum = FP32_OUTPUT_TIMES;
    uint16_t splitNumOfOut = singleMInVec >= maxSplitNum ? maxSplitNum : singleMInVec;
    auto mSizeForOnce = CeilDiv(singleMInVec, static_cast<uint64_t>(splitNumOfOut));
    for (uint16_t i = 0; i < splitNumOfOut; i++) {
        // Element offset of this slice inside the staged matmul result (rows use the aligned pitch).
        uint64_t offsetL0c =
            i * mSizeForOnce * Align(singleN_, static_cast<uint64_t>(UB_ALIGN_SIZE / sizeof(DataTypeIn)));
        if (i * mSizeForOnce >= singleMInVec) {
            // Rounding up the slice size can exhaust the rows before all slices are used.
            break;
        }
        // The last slice may be shorter than the others.
        auto mSize = singleMInVec - i * mSizeForOnce >= mSizeForOnce ? mSizeForOnce : singleMInVec - i * mSizeForOnce;
        AscendC::LocalTensor<DataTypeOut>& dequantOutInUB =
            (dequantOutInUBPingPongID_ == 0) ? dequantOutInUBPing_ : dequantOutInUBPong_;
        // The ping/pong id doubles as the event id (OUTPUT_BUFFER_FLAG_0 / _1): wait until the previous
        // store from this buffer has finished before overwriting it.
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(dequantOutInUBPingPongID_);
        __ubuf__ DataTypeOut* dequantOutInUbAddr = (__ubuf__ DataTypeOut*)dequantOutInUB.GetPhyAddr();
        __ubuf__ DataTypeIn* l0cOutUbAddr = (__ubuf__ DataTypeIn*)l0cOutUb_.GetPhyAddr();
        l0cOutUbAddr = l0cOutUbAddr + offsetL0c;
        switch (dequantTiling_->x1QuantMode) {
            case (Qmm::QuantMode::PERTOKEN_MODE): {
                uint64_t offsetPtScale = i * mSizeForOnce;
                VFDoDequantWithX1Pertoken(dequantOutInUbAddr, l0cOutUbAddr, offsetPtScale, mSize);
                break;
            }
            case (Qmm::QuantMode::PERTENSOR_MODE): {
                VFDoDequantWithX1Pertensor(dequantOutInUbAddr, l0cOutUbAddr, mSize);
                break;
            }
            default: {
                VFDoDequantWithoutX1Scale(dequantOutInUbAddr, l0cOutUbAddr, mSize);
            }
        }
        // Vector work must be complete before the store to GM starts.
        AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(dequantOutInUBPingPongID_);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(dequantOutInUBPingPongID_);
        CopyDequantResFromUb2Gm(mSize, (mOffset + i * mSizeForOnce) * Get<N_IDX>(problemShape_), dequantOutInUB);
        // Marks the buffer as in flight; the next user of the same buffer waits on this flag.
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(dequantOutInUBPingPongID_);
        dequantOutInUBPingPongID_ ^= 1;
    }
    // Release the input buffers for the next block.
    UbSetFlag();
}
// Returns (by value) the UB tensor that Run reads the matmul result from.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline auto BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::GetL0c2UbTensor()
{
return l0cOutUb_;
}
// Sets the L2 cache hint of the output tensor yGlobal_ to CACHE_MODE_DISABLE.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::SetOutL2CacheHint()
{
yGlobal_.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
}
// Call operator: makes the epilogue usable as a functor by forwarding to Run.
QMM_BLOCK_EPILOGUE_DEQUANT_CLASS_LOCAL_PARAMS
__aicore__ inline void
BlockEpilogueDequant<QMM_BLOCK_EPILOGUE_DEQUANT_FUNC_LOCAL_PARAMS>::operator()(BlockShape& blockShape,
                                                                               BlockCoord& blockCoord)
{
    Run(blockShape, blockCoord);
}
}
}
}
#endif
#endif