* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file mc2_quant_batch_matmul.h
* \brief
*/
#ifndef NEW_MC2_QUANT_BATCH_MATMUL
#define NEW_MC2_QUANT_BATCH_MATMUL
#include "../../3rd/quant_batch_matmul_v3/op_kernel/quant_batch_matmul_v3_base.h"
#include "mc2_quant_batch_matmul_asw_block.h"
#define LOCAL_TEMPLATE_CLASS_PARAMS \
template <class X1Type, class X2Type, class ScaleType, class BiasType, class YType, CubeFormat FormatX1, \
CubeFormat FormatX2, CubeFormat FormatY, bool ATrans, bool BTrans>
#define LOCAL_TEMPLATE_FUNC_PARAMS X1Type, X2Type, ScaleType, BiasType, YType, FormatX1, FormatX2, \
FormatY, ATrans, BTrans
namespace Mc2MatmulV3 {
constexpr uint64_t DEQ_SCALE_MUL = 0xFFFFE000;
using namespace AscendC;
LOCAL_TEMPLATE_CLASS_PARAMS
class Mc2QuantBatchMatmulASWKernel {
public:
__aicore__ inline Mc2QuantBatchMatmulASWKernel() {}
__aicore__ inline void Init(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR bias, GM_ADDR scale, GM_ADDR perTokenScale,
GM_ADDR cGM, GM_ADDR workSpace, const void *tilingData, TPipe *que,
const Mc2Tiling::RCSTiling& cfg, bool isTail, bool isGather, uint64_t preCoreNum);
__aicore__ inline void UpdateSlice(uint32_t idx, bool isTail);
__aicore__ inline void Process(bool isLast = true);
__aicore__ inline uint64_t GetPreCoreNum();
__aicore__ inline void MmEnd();
protected:
__aicore__ inline void ComputeBasicOptiLoop(bool isLast = false);
__aicore__ inline void SetMMParaAndCompute();
__aicore__ inline void UpdateGlobalAddr(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR bias, GM_ADDR perTokenScale,
GM_ADDR scale, GM_ADDR cGM, GM_ADDR workSpace);
uint32_t blockIdx_;
const DequantBmm::Mc2QuantBatchMatmulV3TilingDataParams *quantBmmTilingData_;
GlobalTensor<X1Type> aGlobal_;
GlobalTensor<X2Type> bGlobal_;
GlobalTensor<YType> cGlobal_;
GlobalTensor<BiasType> biasGlobal_;
Mc2MatmulV3::QuantBatchMatmulAswBlock block_;
GlobalTensor<fp8_e8m0_t> scaleAGlobal_;
GlobalTensor<fp8_e8m0_t> scaleBGlobal_;
GlobalTensor<uint64_t> scaleGlobal_;
using AType = typename AscendC::Conditional<
DequantBmm::IsMxType<ScaleType>(),
matmul::MatmulTypeWithScale<AscendC::TPosition::GM, AscendC::TPosition::GM, FormatX1, X1Type, ATrans>,
matmul::MatmulType<AscendC::TPosition::GM, FormatX1, X1Type, ATrans>>::type;
using BType = typename AscendC::Conditional<
DequantBmm::IsMxType<ScaleType>(),
matmul::MatmulTypeWithScale<AscendC::TPosition::GM, AscendC::TPosition::GM, FormatX2, X2Type, BTrans>,
matmul::MatmulType<AscendC::TPosition::GM, FormatX2, X2Type, BTrans>>::type;
using CType = matmul::MatmulType<AscendC::TPosition::GM, FormatY, YType>;
using BiasMatmulType = matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, BiasType>;
using MmType = typename AscendC::Conditional<
DequantBmm::IsMxType<ScaleType>(),
matmul::MatmulImpl<AType, BType, CType, BiasMatmulType, MM_CFG_NO_PRELOAD_OPEN_UNIT_FLAG,
MatmulCallBackFunc<nullptr, nullptr, nullptr>, AscendC::Impl::Detail::MatmulWithScalePolicy>,
matmul::MatmulImpl<AType, BType, CType, BiasMatmulType, MM_CFG_NO_PRELOAD_OPEN_UNIT_FLAG>>::type;
MmType mm_;
};
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline void Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::Init(GM_ADDR aGM, GM_ADDR bGM,
GM_ADDR bias,
GM_ADDR scale,
GM_ADDR perTokenScale,
GM_ADDR cGM,
GM_ADDR workSpace,
const void *tilingData,
TPipe *que,
const Mc2Tiling::RCSTiling& cfg,
bool isTail, bool isGather,
uint64_t preCoreNum)
{
if ASCEND_IS_AIV {
return;
}
quantBmmTilingData_ = static_cast<const DequantBmm::Mc2QuantBatchMatmulV3TilingDataParams *>(tilingData);
blockIdx_ = GetBlockIdx();
block_.Init(quantBmmTilingData_, blockIdx_);
block_.InitForMC2(cfg, isTail, isGather, preCoreNum);
UpdateGlobalAddr(aGM, bGM, bias, scale, perTokenScale, cGM, workSpace);
mm_.SetSubBlockIdx(0);
mm_.Init(&quantBmmTilingData_->matmulTiling, que);
}
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline uint64_t Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::GetPreCoreNum()
{
return block_.preCoreNum_;
}
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline void Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::UpdateGlobalAddr(GM_ADDR aGM,
GM_ADDR bGM,
GM_ADDR bias,
GM_ADDR scale,
GM_ADDR perTokenScale,
GM_ADDR cGM,
GM_ADDR workSpace)
{
if constexpr (DequantBmm::IsMxType<ScaleType>()) {
scaleAGlobal_.SetGlobalBuffer((__gm__ fp8_e8m0_t *)perTokenScale);
scaleBGlobal_.SetGlobalBuffer((__gm__ fp8_e8m0_t *)scale);
} else {
if (static_cast<bool>(quantBmmTilingData_->params.isDoubleScale)) {
float deqScale = (*((__gm__ float *)scale)) * (*((__gm__ float *)perTokenScale));
uint32_t uint32Scale = *(reinterpret_cast<uint32_t *>(&deqScale));
block_.offset_.scaleScalar = uint32Scale & DEQ_SCALE_MUL;
} else if (static_cast<bool>(quantBmmTilingData_->params.isPerTensor)) {
if constexpr (!IsSameType<ScaleType, uint64_t>::value) {
uint32_t uint32Scale = 0;
if constexpr (IsSameType<ScaleType, bfloat16_t>::value) {
uint16_t uint16Scale = *((__gm__ uint16_t *)scale);
uint32Scale = uint16Scale << 16;
} else {
uint32Scale = *((__gm__ uint32_t *)scale);
}
block_.offset_.scaleScalar = uint32Scale & DEQ_SCALE_MUL;
} else {
block_.offset_.scaleScalar = *((__gm__ uint64_t *)scale);
}
} else {
scaleGlobal_.SetGlobalBuffer((__gm__ uint64_t *)scale);
}
}
aGlobal_.SetGlobalBuffer((__gm__ X1Type*)aGM);
bGlobal_.SetGlobalBuffer((__gm__ X2Type*)bGM);
if (static_cast<bool>(quantBmmTilingData_->matmulTiling.isBias)) {
biasGlobal_.SetGlobalBuffer((__gm__ BiasType*)bias);
}
cGlobal_.SetGlobalBuffer((__gm__ YType*)cGM);
}
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline void Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::UpdateSlice(
uint32_t idx, bool isTail)
{
this->block_.template UpdateOffset<ScaleType>(idx, isTail);
}
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline void Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::Process(bool isLast)
{
if ASCEND_IS_AIV {
return;
}
if (quantBmmTilingData_->params.batchC == 1) {
block_.offset_.batchCOffset = 0;
block_.offset_.batchAOffset = 0;
block_.offset_.batchBOffset = 0;
ComputeBasicOptiLoop(isLast);
}
}
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline void Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::ComputeBasicOptiLoop(bool isLast)
{
block_.Update();
for (uint64_t roundIndex = 0; roundIndex < block_.params_.round; ++roundIndex) {
if (roundIndex + 1 == this->block_.params_.round && isLast) {
uint64_t lastRoundTileCnt = (block_.params_.totalCnt - block_.tilingData_->matmulTiling.usedCoreNum *
roundIndex) * block_.params_.totalTailTile;
block_.doTailTile_ = (lastRoundTileCnt <= block_.tilingData_->matmulTiling.usedCoreNum);
}
uint64_t newBlockIdx = block_.doTailTile_ ?
(GetBlockIdx() / this->block_.params_.totalTailTile) : GetBlockIdx();
if ((roundIndex == 0) && (newBlockIdx < this->block_.preCoreNum_)) {
continue;
}
if (!block_.UpdateBasicIndex(roundIndex, newBlockIdx)) {
continue;
}
block_.UpdateBlockParams(roundIndex, isLast);
if (block_.params_.singleCoreM <= 0 || block_.params_.singleCoreN <= 0) {
continue;
}
mm_.SetSingleShape(block_.params_.singleCoreM, block_.params_.singleCoreN,
quantBmmTilingData_->matmulTiling.Ka);
block_.template CalcGMOffset<ATrans, BTrans, X1Type, ScaleType, FormatX2>();
SetMMParaAndCompute();
}
this->block_.preCoreNum_ = this->block_.params_.totalCnt % (this->block_.tilingData_->matmulTiling.usedCoreNum);
}
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline void Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::SetMMParaAndCompute()
{
if constexpr (DequantBmm::IsMxType<ScaleType>()) {
mm_.SetTensorScaleA(scaleAGlobal_[block_.offset_.offsetPerTokenScale], ATrans);
mm_.SetTensorScaleB(scaleBGlobal_[block_.offset_.offsetScale], BTrans);
} else {
if (static_cast<bool>(quantBmmTilingData_->params.isPerTensor) ||
static_cast<bool>(quantBmmTilingData_->params.isDoubleScale)) {
mm_.SetQuantScalar(block_.offset_.scaleScalar);
} else {
mm_.SetQuantVector(scaleGlobal_[block_.offset_.offsetScale]);
}
}
if (quantBmmTilingData_->matmulTiling.isBias) {
mm_.SetBias(biasGlobal_[block_.offset_.offsetBias]);
}
mm_.SetTensorA(aGlobal_[block_.offset_.offsetA], ATrans);
mm_.SetTensorB(bGlobal_[block_.offset_.offsetB], BTrans);
mm_.Iterate();
mm_.GetTensorC(cGlobal_[block_.offset_.offsetC]);
}
LOCAL_TEMPLATE_CLASS_PARAMS
__aicore__ inline void Mc2QuantBatchMatmulASWKernel<LOCAL_TEMPLATE_FUNC_PARAMS>::MmEnd()
{
mm_.End();
}
}
#endif