* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file mla_prolog_tiling.cpp
* \brief
*/
#include <numeric>
#include <functional>
#include <algorithm>
#include <graph/utils/type_utils.h>
#include "log/log.h"
#include "err/ops_err.h"
#include "register/op_def_registry.h"
#include "mla_prolog_tiling_check.h"
#include "mla_prolog_tiling.h"
using namespace ge;
namespace optiling {
const std::unordered_map<ge::DataType, uint32_t> DTYPE_TO_SIZE {
{ge::DT_BF16, 2},
{ge::DT_FLOAT16, 2},
{ge::DT_INT8, 1},
{ge::DT_FLOAT8_E4M3FN, 1},
{ge::DT_FLOAT8_E8M0, 1},
{ge::DT_HIFLOAT8, 1},
{ge::DT_INT32, 4},
{ge::DT_FLOAT, 4}};
const std::unordered_map<ge::DataType, matmul_tiling::DataType> GE_TO_MM_DTYPE {
{ge::DT_FLOAT16, matmul_tiling::DataType::DT_FLOAT16},
{ge::DT_BF16, matmul_tiling::DataType::DT_BF16},
{ge::DT_INT8, matmul_tiling::DataType::DT_INT8},
{ge::DT_INT4, matmul_tiling::DataType::DT_INT4},
{ge::DT_FLOAT, matmul_tiling::DataType::DT_FLOAT},
{ge::DT_FLOAT8_E4M3FN, matmul_tiling::DataType::DT_FLOAT8_E4M3FN},
{ge::DT_FLOAT8_E8M0, matmul_tiling::DataType::DT_FLOAT8_E8M0},
{ge::DT_HIFLOAT8, matmul_tiling::DataType::DT_HIFLOAT8}};
template <typename T>
inline auto CeilDiv(T a, T b) -> T
{
if (b == 0) {
return b;
}
return (a + b - 1) / b;
}
template <typename T>
inline auto Align(T num, T rnd) -> T
{
return (((rnd) == 0) ? 0 : (((num) + (rnd) - 1) / (rnd) * (rnd)));
}
NpuArch MlaPrologTiling::GetCurNpuArch() const
{
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context_->platformInfo);
NpuArch npuArch = ascendcPlatform.GetCurNpuArch();
return npuArch;
}
ge::graphStatus MlaPrologTiling::GetNpuInfo()
{
OP_CHECK_IF(context_->platformInfo == nullptr,
OPS_REPORT_VECTOR_INNER_ERR(context_->opName, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context_->platformInfo);
libapiSize_ = ascendcPlatform.GetLibApiWorkSpaceSize();
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize_);
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1Size_);
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0cSize_);
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0bSize_);
aivNum_ = ascendcPlatform.GetCoreNumAiv();
aicNum_ = ascendcPlatform.GetCoreNumAic();
OP_CHECK_IF(aicNum_ == 0 || aivNum_ == 0,
OPS_REPORT_VECTOR_INNER_ERR(context_->opName, "num of core obtained is 0."), return GRAPH_FAILED);
OP_CHECK_IF((aicNum_ != aivNum_) && (aicNum_ * 2 != aivNum_),
OPS_REPORT_VECTOR_INNER_ERR(context_->opName, "aicNum(%u):aivNum(%u) only support 1:1 or 1:2", aicNum_, aivNum_), return GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
QUANT_MODE MlaPrologTiling::GetQuantizationModeV3() const
{
if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::NO_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::NO_QUANT;
} else {
OP_LOGE(context_->opName, "When weightQuantMode == 0, kvQuantMode must be within {0}, actually is %d.", *(context_->kvQuantMode));
}
} else if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::PARTIAL_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::PARTIAL_QUANT_KV_NO_QUANT;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_CHANNEL)) {
return QUANT_MODE::PARTIAL_QUANT_KV_QUANT_PER_CHANNEL;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TILE)) {
return QUANT_MODE::PARTIAL_QUANT_KV_QUANT_PER_TILE;
} else {
OP_LOGE(context_->opName, "When weightQuantMode == 1, kvQuantMode must be within {0, 2, 3}, actually is %d.", *(context_->kvQuantMode));
}
} else if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::FULL_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::FULL_QUANT_KV_NO_QUANT;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TENSOR)) {
return QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TENSOR;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TILE)) {
return QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TILE;
} else {
OP_LOGE(context_->opName, "When weightQuantMode == 2, kvQuantMode must be within {0, 1, 3}, actually is %d.", *(context_->kvQuantMode));
}
} else {
OP_LOGE(context_->opName, "WeightQuantMode must be within {0, 1, 2}, actually is %d.", *(context_->weightQuantMode));
}
return QUANT_MODE::ERROR_MODE;
}
QUANT_MODE MlaPrologTiling::GetQuantizationModeV3Mxfp8() const
{
if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::NO_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::NO_QUANT;
} else {
OP_LOGE(context_->opName, "When weightQuantMode == 0, kvQuantMode must be within {0}, actually is %d.", *(context_->kvQuantMode));
}
} else if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::PARTIAL_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::PARTIAL_QUANT_KV_NO_QUANT;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_CHANNEL)) {
return QUANT_MODE::PARTIAL_QUANT_KV_QUANT_PER_CHANNEL;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TILE)) {
return QUANT_MODE::PARTIAL_QUANT_KV_QUANT_PER_TILE;
} else {
OP_LOGE(context_->opName,
"When weightQuantMode == 1, kvQuantMode must be within {0, 2, 3}, actually is %d.",
*(context_->kvQuantMode));
}
} else if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::FULL_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::FULL_QUANT_KV_NO_QUANT;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TENSOR)) {
return QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TENSOR;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TILE)) {
return QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TILE;
} else {
OP_LOGE(context_->opName,
"When weightQuantMode == 2, kvQuantMode must be within {0, 1, 3}, actually is %d.",
*(context_->kvQuantMode));
}
} else if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::FP8_FULL_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::FP8_FULL_QUANT_KV_NO_QUANT;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TENSOR)) {
return QUANT_MODE::FP8_FULL_QUANT_KV_QUANT_PER_TENSOR;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TILE)) {
return QUANT_MODE::FP8_FULL_QUANT_KV_QUANT_PER_TILE;
} else {
OP_LOGE(context_->opName,
"When weightQuantMode == 4, kvQuantMode must be within {0, 1, 3}, actually is %d.",
*(context_->kvQuantMode));
}
} else if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::HIF8_FULL_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::HIF8_FULL_QUANT_KV_NO_QUANT;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TENSOR)) {
return QUANT_MODE::HIF8_FULL_QUANT_KV_QUANT_PER_TENSOR;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TILE)) {
return QUANT_MODE::HIF8_FULL_QUANT_KV_QUANT_PER_TILE;
} else {
OP_LOGE(context_->opName,
"When weightQuantMode == 5, kvQuantMode must be within {0, 1, 3}, actually is %d.",
*(context_->kvQuantMode));
}
} else if (*(context_->weightQuantMode) == static_cast<int>(WEIGHT_QUANT_MODE::MXFP8_FULL_QUANT)) {
if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::NO_QUANT)) {
return QUANT_MODE::MXFP8_FULL_QUANT_KV_NO_QUANT;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TENSOR)) {
return QUANT_MODE::MXFP8_FULL_QUANT_KV_QUANT_PER_TENSOR;
} else if (*(context_->kvQuantMode) == static_cast<int>(KV_QUANT_MODE::PER_TILE)) {
return QUANT_MODE::MXFP8_FULL_QUANT_KV_QUANT_PER_TILE;
} else {
OP_LOGE(context_->opName, "When weightQuantMode == 3, kvQuantMode must be within {0, 1, 3}, actually is %d.", *(context_->kvQuantMode));
}
} else {
OP_LOGE(context_->opName,
"weightQuantMode must be within {0, 1, 2, 3, 4, 5}, actually is %d.",
*(context_->weightQuantMode));
}
return QUANT_MODE::ERROR_MODE;
}
QUANT_MODE MlaPrologTiling::GetQuantizationMode() const
{
if (std::strncmp(context_->opType, V3_OP_NAME, OP_NAME_LEN) == 0) {
if (GetCurNpuArch() == NpuArch::DAV_3510){
return GetQuantizationModeV3Mxfp8();
} else {
return GetQuantizationModeV3();
}
} else {
if (context_->tokenX.desc->GetDataType() == ge::DT_INT8) {
if (context_->kvCache.desc->GetDataType() == ge::DT_INT8) {
return QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TENSOR;
} else {
return QUANT_MODE::FULL_QUANT_KV_NO_QUANT;
}
}
if (context_->tokenX.desc->GetDataType() == ge::DT_BF16 && context_->weightUqQr.desc->GetDataType() == ge::DT_INT8) {
if (context_->kvCache.desc->GetDataType() == ge::DT_INT8) {
return QUANT_MODE::PARTIAL_QUANT_KV_QUANT_PER_CHANNEL;
} else {
return QUANT_MODE::PARTIAL_QUANT_KV_NO_QUANT;
}
}
return QUANT_MODE::NO_QUANT;
}
return QUANT_MODE::ERROR_MODE;
}
ge::graphStatus MlaPrologTiling::SetShapeInfo()
{
if (context_->tokenX.shape->GetStorageShape().GetDimNum() == MLA_PROLOG_DIM_NUM_3) {
baseShapeInfo_.bSize = context_->tokenX.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0);
baseShapeInfo_.s1Size = context_->tokenX.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_1);
baseShapeInfo_.heSize = context_->tokenX.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_2);
baseShapeInfo_.tSize = baseShapeInfo_.bSize * baseShapeInfo_.s1Size;
} else {
baseShapeInfo_.tSize = context_->tokenX.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0);
baseShapeInfo_.heSize = context_->tokenX.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_1);
}
if (context_->weightDq.shape->GetStorageShape().GetDimNum() == MLA_PROLOG_DIM_NUM_2) {
baseShapeInfo_.hcqSize = context_->weightDq.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_1);
} else {
uint32_t weightDqAxisSize_ = 32U / ge::GetSizeByDataType(context_->weightDq.desc->GetDataType());
baseShapeInfo_.hcqSize =
weightDqAxisSize_ * context_->weightDq.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0);
}
baseShapeInfo_.nSize = context_->weightUk.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0);
baseShapeInfo_.drSize =
context_->ropeCos.shape->GetStorageShape().GetDim(context_->ropeCos.shape->GetStorageShape().GetDimNum() - 1);
baseShapeInfo_.dSize = context_->weightUk.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_1);
baseShapeInfo_.headSizeQc = baseShapeInfo_.dSize * baseShapeInfo_.nSize;
baseShapeInfo_.headSizeQr = baseShapeInfo_.drSize * baseShapeInfo_.nSize;
baseShapeInfo_.headSizeUqQr = baseShapeInfo_.headSizeQc + baseShapeInfo_.headSizeQr;
if (context_->kvCache.shape->GetStorageShape().GetDimNum() == MLA_PROLOG_DIM_NUM_3) {
baseShapeInfo_.blockNum = context_->kvCache.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0);
baseShapeInfo_.nkvSize = context_->kvCache.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_1);
baseShapeInfo_.dtileSize = context_->kvCache.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_2);
} else {
baseShapeInfo_.blockNum = context_->kvCache.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0);
baseShapeInfo_.blockSize = context_->kvCache.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_1);
baseShapeInfo_.nkvSize = context_->kvCache.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_2);
baseShapeInfo_.dtileSize = context_->kvCache.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_3);
}
if (context_->weightDkvKr.shape->GetStorageShape().GetDimNum() == MLA_PROLOG_DIM_NUM_4) {
baseShapeInfo_.hckvSize =
context_->weightDkvKr.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0) *
context_->weightDkvKr.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_3) -
baseShapeInfo_.drSize;
} else {
baseShapeInfo_.hckvSize =
context_->weightDkvKr.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_1) - baseShapeInfo_.drSize;
}
baseShapeInfo_.s2Size = baseShapeInfo_.nkvSize;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::SetScenarioInfo()
{
scenarioInfo_.isV1Flag_ = (std::strncmp(context_->opType, V1_OP_NAME, OP_NAME_LEN) == 0);
scenarioInfo_.batchSeqFusedFlag_ = context_->tokenX.shape->GetStorageShape().GetDimNum() == MLA_PROLOG_DIM_NUM_2;
scenarioInfo_.quantMode_ = GetQuantizationMode();
if (scenarioInfo_.quantMode_ == QUANT_MODE::ERROR_MODE) {
return ge::GRAPH_FAILED;
}
if (std::strncmp(context_->cacheMode, CACHE_MODE_BSND, CACHE_MODE_LEN) == 0) {
scenarioInfo_.cacheMode_ = CACHE_MODE::BSND;
} else if (std::strncmp(context_->cacheMode, CACHE_MODE_TND, CACHE_MODE_LEN) == 0) {
scenarioInfo_.cacheMode_ = CACHE_MODE::TND;
} else if (std::strncmp(context_->cacheMode, CACHE_MODE_PA_BSND, CACHE_MODE_LEN) == 0) {
scenarioInfo_.cacheMode_ = CACHE_MODE::PA_BSND;
} else if (std::strncmp(context_->cacheMode, CACHE_MODE_PA_NZ, CACHE_MODE_LEN) == 0) {
scenarioInfo_.cacheMode_ = CACHE_MODE::PA_NZ;
} else if (std::strncmp(context_->cacheMode, CACHE_MODE_PA_BLK_BSND, CACHE_MODE_LEN) == 0) {
scenarioInfo_.cacheMode_ = CACHE_MODE::PA_BLK_BSND;
} else {
scenarioInfo_.cacheMode_ = CACHE_MODE::PA_BLK_NZ;
}
if ((scenarioInfo_.batchSeqFusedFlag_ && baseShapeInfo_.tSize == 0U) ||
(!scenarioInfo_.batchSeqFusedFlag_ && (baseShapeInfo_.bSize * baseShapeInfo_.s1Size == 0U))) {
scenarioInfo_.emptyTensorMode_ = EMPTY_TENSOR_MODE::EMPTY_QUERY;
} else if (baseShapeInfo_.blockNum == 0U) {
scenarioInfo_.emptyTensorMode_ = EMPTY_TENSOR_MODE::EMPTY_CACHE;
} else {
scenarioInfo_.emptyTensorMode_ = EMPTY_TENSOR_MODE::NON_EMPTY;
}
if (scenarioInfo_.batchSeqFusedFlag_ &&
(scenarioInfo_.cacheMode_ == CACHE_MODE::PA_BLK_BSND || scenarioInfo_.cacheMode_ == CACHE_MODE::PA_BLK_NZ)) {
OP_CHECK_IF(context_->actualSeqLen.shape == nullptr,
OP_LOGE(context_->opName,
"When cacheMode in {PA_BLK_BSND, PA_BLK_NZ} and tokenX shape dim num is 2,"
"actualSeqLen should not be null."), return GRAPH_FAILED);
baseShapeInfo_.bSize = context_->actualSeqLen.shape->GetStorageShape().GetDim(MLA_PROLOG_DIM_INDEX_0);
scenarioInfo_.actualSeqMode_ = ACTUAL_SEQ_MODE::EN_Q_LEN;
} else {
scenarioInfo_.actualSeqMode_ = ACTUAL_SEQ_MODE::DISABLED;
}
uint32_t cvRatio = aivNum_ / aicNum_;
scenarioInfo_.splitMFlag_ = 0U;
if ((scenarioInfo_.quantMode_ == QUANT_MODE::MXFP8_FULL_QUANT_KV_NO_QUANT ||
scenarioInfo_.quantMode_ == QUANT_MODE::MXFP8_FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::MXFP8_FULL_QUANT_KV_QUANT_PER_TILE) &&
baseShapeInfo_.tSize >= 8192 && cvRatio == 2) {
scenarioInfo_.splitMFlag_ = 1U;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::SetAttrInfo()
{
reciprocalCq_ = 1.0f / baseShapeInfo_.hcqSize;
epsilonCq_ = *(context_->rmsNormEspilonCq);
reciprocalCkv_ = 1.0f / baseShapeInfo_.hckvSize;
epsilonCkv_ = *(context_->rmsNormEspilonCkv);
if (std::strncmp(context_->opType, V3_OP_NAME, OP_NAME_LEN) == 0) {
queryNormFlag_ = *(context_->queryNormFlag);
weightQuantMode_ = static_cast<WEIGHT_QUANT_MODE>(*(context_->weightQuantMode));
kvQuantMode_ = static_cast<KV_QUANT_MODE>(*(context_->kvQuantMode));
queryQuantMode_ = static_cast<QUERY_QUANT_MODE>(*(context_->queryQuantMode));
ckvkrRepoMode_ = static_cast<CKVKR_REPO_MODE>(*(context_->ckvkrRepoMode));
quantSacleRepoMode_ = static_cast<QUANT_SCALE_REPO_MODE>(*(context_->quantScaleRepoMode));
tileSize_ = static_cast<uint32_t>(*(context_->tileSize));
qcQrScale_ = *(context_->qcQrScale);
kcScale_ = *(context_->kcScale);
}
return ge::GRAPH_SUCCESS;
}
bool MlaPrologTiling::GetMatmulType(ge::DataType getype, matmul_tiling::DataType *mmType)
{
auto mmdt = GE_TO_MM_DTYPE.find(getype);
if (mmdt != GE_TO_MM_DTYPE.end()) {
*mmType = mmdt->second;
return true;
}
return false;
}
uint32_t MlaPrologTiling::CalcSingleCoreN(uint32_t n, uint32_t coreNum, uint32_t alignNum) const
{
return CeilDiv(n, alignNum * coreNum) * alignNum;
}
ge::graphStatus MlaPrologTiling::FillMatmul1Tiling()
{
if (scenarioInfo_.splitMFlag_ == 1U) {
singlecoreHeadSizeCq_ = baseShapeInfo_.hcqSize;
mm1BlockNum_ = aicNum_;
} else {
auto dataType = context_->weightDq.desc->GetDataType();
singlecoreHeadSizeCq_ =
CalcSingleCoreN(baseShapeInfo_.hcqSize, aicNum_, BLOCK_SIZE / DTYPE_TO_SIZE.at(dataType));
singlecoreHeadSizeCq_ = std::max(singlecoreHeadSizeCq_, 64U);
mm1BlockNum_ = CeilDiv(baseShapeInfo_.hcqSize, singlecoreHeadSizeCq_);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::FillMatmul2Tiling()
{
if (scenarioInfo_.emptyTensorMode_ == EMPTY_TENSOR_MODE::EMPTY_CACHE) {
return ge::GRAPH_SUCCESS;
}
if (scenarioInfo_.splitMFlag_ == 1U) {
singlecoreHeadSizeCkvKr_ = baseShapeInfo_.hckvSize + baseShapeInfo_.drSize;
mm2BlockNum_ = aicNum_;
} else if (aicNum_ >= 9U) {
uint32_t baseN = 64U;
mm2BlockNum_ = (baseShapeInfo_.hckvSize + baseShapeInfo_.drSize) / baseN;
singlecoreHeadSizeCkvKr_ = baseN;
} else {
auto dataType = context_->weightDkvKr.desc->GetDataType();
singlecoreHeadSizeCkvKr_ = CalcSingleCoreN(baseShapeInfo_.hckvSize + baseShapeInfo_.drSize, aicNum_,
BLOCK_SIZE / DTYPE_TO_SIZE.at(dataType));
mm2BlockNum_ = CeilDiv(baseShapeInfo_.hckvSize + baseShapeInfo_.drSize, singlecoreHeadSizeCkvKr_);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::FillMatmul3Tiling()
{
auto dataType = context_->weightUqQr.desc->GetDataType();
auto oriM = baseShapeInfo_.nSize * (baseShapeInfo_.dSize + baseShapeInfo_.drSize);
if (enableGroupComputeOpt_) {
singlecoreHeadSizeQcQr_ =
CalcSingleCoreN(baseShapeInfo_.nSize * baseShapeInfo_.dSize,
GROUP_COMPUTE_CUBE_NUM_PER_GROUP, baseShapeInfo_.dSize);
} else if (enableDequantOpt_) {
singlecoreHeadSizeQcQr_ = CalcSingleCoreN(oriM, aicNum_, baseShapeInfo_.dSize + baseShapeInfo_.drSize);
} else {
singlecoreHeadSizeQcQr_ = CalcSingleCoreN(oriM, aicNum_, BLOCK_SIZE / DTYPE_TO_SIZE.at(dataType));
}
mm3BlockNum_ = CeilDiv(oriM, singlecoreHeadSizeQcQr_);
if (scenarioInfo_.splitMFlag_ == 1U) {
singlecoreHeadSizeQcQr_ = oriM;
mm3BlockNum_ = aicNum_;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::FillMatmul4Tiling()
{
if (scenarioInfo_.splitMFlag_ == 1U) {
singlecoreNumHeadSize_ = baseShapeInfo_.nSize;
mm4BlockNum_ = aicNum_;
} else {
singlecoreNumHeadSize_ = CeilDiv(baseShapeInfo_.nSize, aicNum_);
mm4BlockNum_ = CeilDiv(baseShapeInfo_.nSize, singlecoreNumHeadSize_);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::ProcessBaseInputs()
{
stepBatchSize_ = std::min(128U, baseShapeInfo_.tSize);
if (scenarioInfo_.splitMFlag_ == 1U && (stepBatchSize_ > 0U) && (aicNum_ > 0U) &&
(baseShapeInfo_.tSize > 0U)) {
mSubSize_ = (baseShapeInfo_.tSize + aicNum_ - 1U) / aicNum_;
mSubCoreNum_ = baseShapeInfo_.tSize - (mSubSize_ - 1U) * aicNum_;
}
if (baseShapeInfo_.dSize == HIGH_THROUGHPUT__D_SIZE) {
stepNumHeadDequant_ = std::min(64U, baseShapeInfo_.nSize);
} else {
stepNumHeadDequant_ = std::min(16U, baseShapeInfo_.nSize);
}
vectorBlockNum_ = std::min(stepBatchSize_, aivNum_);
uint32_t cvRatio = aivNum_ / aicNum_;
if ((scenarioInfo_.quantMode_ == QUANT_MODE::PARTIAL_QUANT_KV_NO_QUANT ||
scenarioInfo_.quantMode_ == QUANT_MODE::PARTIAL_QUANT_KV_QUANT_PER_CHANNEL) &&
baseShapeInfo_.tSize == GROUP_COMPUTE_T_SIZE &&
baseShapeInfo_.nkvSize == GROUP_COMPUTE_N_SIZE &&
aivNum_ >= GROUP_COMPUTE_MIN_AIV_NUM &&
aicNum_ >= GROUP_COMPUTE_MIN_AIC_NUM &&
cvRatio != 1) {
enableGroupComputeOpt_ = true;
aivNum_ = 32U;
aicNum_ = 16U;
} else if ((context_->weightUqQr.desc->GetDataType() == ge::DT_INT8 &&
baseShapeInfo_.nSize >= GROUP_COMPUTE_N_SIZE) ||
context_->weightUqQr.desc->GetDataType() == ge::DT_FLOAT8_E4M3FN ||
context_->weightUqQr.desc->GetDataType() == ge::DT_HIFLOAT8) {
enableDequantOpt_ = true;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::FillTiling()
{
baseParams_->tokenSize = baseShapeInfo_.tSize;
baseParams_->seq1Size = baseShapeInfo_.s1Size;
baseParams_->seq2Size = baseShapeInfo_.s2Size;
baseParams_->headSizeX = baseShapeInfo_.heSize;
baseParams_->headSizeCq = baseShapeInfo_.hcqSize;
baseParams_->headSizeCkv = baseShapeInfo_.hckvSize;
baseParams_->dtileSize = baseShapeInfo_.dtileSize;
baseParams_->headSizeQc = baseShapeInfo_.headSizeQc;
baseParams_->headSizeQr = baseShapeInfo_.headSizeQr;
baseParams_->headSizeKr = baseShapeInfo_.drSize;
baseParams_->numHeadSize = baseShapeInfo_.nSize;
baseParams_->numHeadKvSize = baseShapeInfo_.nkvSize;
baseParams_->dimHeadSizeQc = baseShapeInfo_.dSize;
baseParams_->dimHeadRope = baseShapeInfo_.drSize;
baseParams_->blockNum = baseShapeInfo_.blockNum;
baseParams_->blockSize = baseShapeInfo_.blockSize;
baseParams_->reciprocalCq = reciprocalCq_;
baseParams_->epsilonCq = epsilonCq_;
baseParams_->reciprocalCkv = reciprocalCkv_;
baseParams_->epsilonCkv = epsilonCkv_;
if (std::strncmp(context_->opType, V3_OP_NAME, OP_NAME_LEN) == 0) {
baseParams_->queryNormFlag = queryNormFlag_ ? 1U : 0U;
baseParams_->kvQuantMode = static_cast<uint32_t>(kvQuantMode_);
baseParams_->ckvkrRepoMode = static_cast<uint32_t>(ckvkrRepoMode_);
baseParams_->quantScaleRepoMode = static_cast<uint32_t>(quantSacleRepoMode_);
baseParams_->tileSize = tileSize_;
baseParams_->qcQrScale = qcQrScale_;
baseParams_->kcScale = kcScale_;
baseParams_->isQcQrScaleEnable = static_cast<uint16_t>(std::abs(qcQrScale_ - 1.0f) >= std::numeric_limits<float>::epsilon());
baseParams_->isKcScaleEnable = static_cast<uint16_t>(std::abs(kcScale_ - 1.0f) >= std::numeric_limits<float>::epsilon());
} else {
baseParams_->queryNormFlag = 0U;
baseParams_->kvQuantMode = 0U;
baseParams_->ckvkrRepoMode = 0U;
baseParams_->quantScaleRepoMode = 0U;
baseParams_->tileSize = 128U;
baseParams_->qcQrScale = 0;
baseParams_->kcScale = 0;
baseParams_->isQcQrScaleEnable = 0U;
baseParams_->isKcScaleEnable = 0U;
}
FillTilingCoreParams();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::FillTilingCoreParams()
{
baseParams_->batchSize = baseShapeInfo_.bSize;
baseParams_->stepBatchSize = stepBatchSize_;
baseParams_->stepNumHeadDequant = stepNumHeadDequant_;
baseParams_->mSubSize = mSubSize_;
baseParams_->mSubCoreNum = mSubCoreNum_;
baseParams_->mm1BlockNum = mm1BlockNum_;
baseParams_->mm2BlockNum = mm2BlockNum_;
baseParams_->mm3BlockNum = mm3BlockNum_;
baseParams_->mm4BlockNum = mm4BlockNum_;
baseParams_->mm1SingleCoreN = singlecoreHeadSizeCq_;
baseParams_->mm2SingleCoreN = singlecoreHeadSizeCkvKr_;
baseParams_->mm3SingleCoreN = singlecoreHeadSizeQcQr_;
baseParams_->mm4SingleCoreBatch = singlecoreNumHeadSize_;
baseParams_->vectorBlockNum = vectorBlockNum_;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::CalcWorkSpace()
{
workspaceSize_ = libapiSize_;
uint32_t mm1Mult = (scenarioInfo_.splitMFlag_ == 1U) ? mm1BlockNum_ : 1U;
uint32_t mm2Mult = (scenarioInfo_.splitMFlag_ == 1U) ? mm2BlockNum_ : 1U;
uint32_t mm3Mult = (scenarioInfo_.splitMFlag_ == 1U) ? mm3BlockNum_ : 1U;
uint32_t mm4Mult = (scenarioInfo_.splitMFlag_ == 1U) ? mm4BlockNum_ : 1U;
uint32_t dequantScaleMult = (scenarioInfo_.splitMFlag_ == 1U) ? static_cast<uint32_t>(aicNum_) : 1U;
if (scenarioInfo_.quantMode_ == QUANT_MODE::FULL_QUANT_KV_NO_QUANT ||
scenarioInfo_.quantMode_ == QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TILE ||
scenarioInfo_.quantMode_ == QUANT_MODE::MXFP8_FULL_QUANT_KV_NO_QUANT ||
scenarioInfo_.quantMode_ == QUANT_MODE::MXFP8_FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::MXFP8_FULL_QUANT_KV_QUANT_PER_TILE ||
scenarioInfo_.quantMode_ == QUANT_MODE::FP8_FULL_QUANT_KV_NO_QUANT ||
scenarioInfo_.quantMode_ == QUANT_MODE::FP8_FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::HIF8_FULL_QUANT_KV_NO_QUANT ||
scenarioInfo_.quantMode_ == QUANT_MODE::HIF8_FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::FP8_FULL_QUANT_KV_QUANT_PER_TILE ||
scenarioInfo_.quantMode_ == QUANT_MODE::HIF8_FULL_QUANT_KV_QUANT_PER_TILE) {
workspaceSize_ += static_cast<size_t>(stepBatchSize_) * static_cast<size_t>(baseShapeInfo_.hcqSize) *
mm1Mult * static_cast<size_t>(NUM_BYTES_INT32);
workspaceSize_ += static_cast<size_t>(stepBatchSize_) * static_cast<size_t>(baseShapeInfo_.hcqSize) *
mm1Mult * static_cast<size_t>(NUM_BYTES_BF16);
workspaceSize_ += static_cast<size_t>(stepBatchSize_) *
static_cast<size_t>(baseShapeInfo_.hckvSize + baseShapeInfo_.drSize) *
mm2Mult * static_cast<size_t>(NUM_BYTES_INT32);
if (scenarioInfo_.quantMode_ == QUANT_MODE::FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::MXFP8_FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::FP8_FULL_QUANT_KV_QUANT_PER_TENSOR ||
scenarioInfo_.quantMode_ == QUANT_MODE::HIF8_FULL_QUANT_KV_QUANT_PER_TENSOR) {
workspaceSize_ += static_cast<size_t>(stepBatchSize_) * static_cast<size_t>(baseShapeInfo_.nSize) *
static_cast<size_t>(baseShapeInfo_.hckvSize) * mm4Mult * static_cast<size_t>(NUM_BYTES_BF16);
}
} else {
workspaceSize_ += static_cast<size_t>(stepBatchSize_) * static_cast<size_t>(baseShapeInfo_.hcqSize) *
mm1Mult * static_cast<size_t>(NUM_BYTES_BF16) * static_cast<size_t>(2);
workspaceSize_ += static_cast<size_t>(stepBatchSize_) *
static_cast<size_t>(baseShapeInfo_.hckvSize + baseShapeInfo_.drSize) *
mm2Mult * static_cast<size_t>(NUM_BYTES_BF16);
}
workspaceSize_ += static_cast<size_t>(stepBatchSize_) *
static_cast<size_t>(baseShapeInfo_.headSizeQc + baseShapeInfo_.headSizeQr) *
mm3Mult * static_cast<size_t>(NUM_BYTES_INT32);
workspaceSize_ += static_cast<size_t>(stepBatchSize_) * static_cast<size_t>(baseShapeInfo_.nSize) *
static_cast<size_t>(baseShapeInfo_.dSize) * mm3Mult * static_cast<size_t>(NUM_BYTES_BF16);
if (enableGroupComputeOpt_ || enableDequantOpt_) {
workspaceSize_ += static_cast<size_t>(stepBatchSize_) * dequantScaleMult * static_cast<size_t>(BLOCK_SIZE);
}
if (context_->workSpaces) {
context_->workSpaces[0] = workspaceSize_;
}
OP_LOGI(context_->opName, "Tiling info: workspaceSize_ = %zu", workspaceSize_);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::GenTilingKey() const
{
uint8_t typeValue = 0;
uint8_t quantType = 0;
if (scenarioInfo_.quantMode_ == QUANT_MODE::NO_QUANT) {
typeValue = 1U;
} else {
typeValue = 2U;
quantType = static_cast<uint8_t>(scenarioInfo_.quantMode_);
}
uint8_t cvMode = ASCENDC_TPL_MIX_AIC_1_2;
if (aivNum_ == aicNum_) {
cvMode = ASCENDC_TPL_MIX_AIC_1_1;
}
if (cvMode == ASCENDC_TPL_MIX_AIC_1_1 &&
(scenarioInfo_.quantMode_ != QUANT_MODE::NO_QUANT ||
(scenarioInfo_.cacheMode_ != CACHE_MODE::PA_BSND && scenarioInfo_.cacheMode_ != CACHE_MODE::PA_NZ))) {
OP_LOGE(context_->opName,
"CV1:1 mode only support quantMode is in {NO_QUANT} and cacheMode is in {PA_BSND,PA_NZ}, quantMode is %d, cacheMode is %u.",
scenarioInfo_.quantMode_, scenarioInfo_.cacheMode_);
return ge::GRAPH_FAILED;
}
if (scenarioInfo_.emptyTensorMode_ == EMPTY_TENSOR_MODE::EMPTY_QUERY) {
context_->tilingKey = GET_TPL_TILING_KEY(0, 0, 0, false, false,
static_cast<uint8_t>(scenarioInfo_.emptyTensorMode_), 0, 0, cvMode);
} else {
uint8_t cacheMode = scenarioInfo_.cacheMode_ == CACHE_MODE::TND ? 0 : static_cast<uint8_t>(scenarioInfo_.cacheMode_);
context_->tilingKey = GET_TPL_TILING_KEY(static_cast<uint8_t>(cacheMode), typeValue, quantType,
enableDequantOpt_, enableGroupComputeOpt_, static_cast<uint8_t>(scenarioInfo_.emptyTensorMode_),
static_cast<uint8_t>(scenarioInfo_.actualSeqMode_),
static_cast<uint8_t>(scenarioInfo_.splitMFlag_), cvMode);
OP_LOGI(context_->opName, "MlaProlog tilingKey args: "
"CACHE_MODE:%u, SCENARIO:%u, QUANT_MODE:%u, ENABLE_DEQUANT_OPTIONAL:%u, ENABLE_GROUP_COMPUTE_OPTIONAL:%u, "
"EMPTY_TENSOR_MODE:%u, ACTUAL_SEQ_LEN_MODE:%u, SPLIT_M_MODE:%u, CV_MODE:%u",
static_cast<uint8_t>(cacheMode), typeValue, quantType, enableDequantOpt_, enableGroupComputeOpt_,
static_cast<uint8_t>(scenarioInfo_.emptyTensorMode_), static_cast<uint8_t>(scenarioInfo_.actualSeqMode_),
static_cast<uint8_t>(scenarioInfo_.splitMFlag_), cvMode);
}
OP_LOGI(context_->opName, "MlaProlog tilingKey:%lu", context_->tilingKey);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::RunBigKernelTiling(MlaPrologContext &context, MlaPrologTilingData* tilingData)
{
this->context_ = &context;
this->baseParams_ = &tilingData->baseParams;
MlaPrologTilingCheck tilingCheck_ {*context_, baseShapeInfo_, scenarioInfo_};
using StatusFunction = std::function<ge::graphStatus()>;
std::vector<StatusFunction> requiredTilingFuncs {
std::bind(&MlaPrologTiling::GetNpuInfo, this),
std::bind(&MlaPrologTilingCheck::CheckAttrs, &tilingCheck_),
std::bind(&MlaPrologTilingCheck::CheckSingleRequiredParam, &tilingCheck_),
std::bind(&MlaPrologTilingCheck::CheckCacheMode, &tilingCheck_),
std::bind(&MlaPrologTiling::SetShapeInfo, this),
std::bind(&MlaPrologTiling::SetScenarioInfo, this),
std::bind(&MlaPrologTilingCheck::CheckScenarParam, &tilingCheck_),
std::bind(&MlaPrologTilingCheck::CheckQuantMode, &tilingCheck_),
std::bind(&MlaPrologTilingCheck::CheckDims, &tilingCheck_),
std::bind(&MlaPrologTilingCheck::CheckSpecialScenarioParamShape, &tilingCheck_),
std::bind(&MlaPrologTilingCheck::CheckParamByScenario, &tilingCheck_),
std::bind(&MlaPrologTiling::SetAttrInfo, this),
std::bind(&MlaPrologTiling::ProcessBaseInputs, this),
};
for (const auto &func: requiredTilingFuncs) {
if (func() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
}
if (scenarioInfo_.emptyTensorMode_ == EMPTY_TENSOR_MODE::EMPTY_QUERY) {
FillTiling();
if (context_->workSpaces) {
context_->workSpaces[0] = libapiSize_;
}
GenTilingKey();
context_->blockDim = 1U;
return ge::GRAPH_SUCCESS;
}
std::vector<StatusFunction> optionalTilingFuncs {
std::bind(&MlaPrologTiling::FillMatmul1Tiling, this),
std::bind(&MlaPrologTiling::FillMatmul2Tiling, this),
std::bind(&MlaPrologTiling::FillMatmul3Tiling, this),
std::bind(&MlaPrologTiling::FillMatmul4Tiling, this),
std::bind(&MlaPrologTiling::FillTiling, this),
std::bind(&MlaPrologTiling::CalcWorkSpace, this),
std::bind(&MlaPrologTiling::GenTilingKey, this)
};
for (const auto &func : optionalTilingFuncs) {
if (func() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
}
context_->blockDim = aicNum_;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MlaPrologTiling::ConvertContext(gert::TilingContext &context, MlaPrologContext &mlaPrologContext)
{
OP_CHECK_IF(context.GetNodeName() == nullptr,
OP_LOGE(V1_OP_NAME, "OpName got from TilingContext is nullptr."), return ge::GRAPH_FAILED);
mlaPrologContext.opName = context.GetNodeName();
mlaPrologContext.opType = context.GetNodeType();
mlaPrologContext.platformInfo = context.GetPlatformInfo();
ConvertRequiredParams(context, mlaPrologContext);
ConvertOptionalParams(context, mlaPrologContext);
auto attrs = context.GetAttrs();
OP_CHECK_IF(attrs == nullptr,
OP_LOGE(context.GetNodeName(), "Attrs got from ge is nullptr."), return ge::GRAPH_FAILED);
mlaPrologContext.rmsNormEspilonCq = attrs->GetAttrPointer<float>(RMS_NORM_EPSILON_CQ_ATTR_INDEX);
mlaPrologContext.rmsNormEspilonCkv = attrs->GetAttrPointer<float>(RMS_NORM_EPSILON_CKV_ATTR_INDEX);
mlaPrologContext.cacheMode = attrs->GetStr(CACHE_MODE_ATTR_INDEX);
if (std::strncmp(mlaPrologContext.opType, V3_OP_NAME, OP_NAME_LEN) == 0) {
mlaPrologContext.queryNormFlag = attrs->GetAttrPointer<bool>(QUERY_NORM_FLAG_ATTR_INDEX);
mlaPrologContext.weightQuantMode = attrs->GetAttrPointer<int64_t>(WEIGHT_QUANT_MODE_ATTR_INDEX);
mlaPrologContext.kvQuantMode = attrs->GetAttrPointer<int64_t>(KV_CACHE_QUANT_MODE_ATTR_INDEX);
mlaPrologContext.queryQuantMode = attrs->GetAttrPointer<int64_t>(QUERY_QUANT_MODE_ATTR_INDEX);
mlaPrologContext.ckvkrRepoMode = attrs->GetAttrPointer<int64_t>(CKVKR_REPO_MODE_ATTR_INDEX);
mlaPrologContext.quantScaleRepoMode = attrs->GetAttrPointer<int64_t>(QUANT_SCALE_REPO_MODE_ATTR_INDEX);
mlaPrologContext.tileSize = attrs->GetAttrPointer<int64_t>(TILE_SIZE_ATTR_INDEX);
mlaPrologContext.qcQrScale = attrs->GetAttrPointer<float>(QC_QR_SCALE_ATTR_INDEX);
mlaPrologContext.kcScale = attrs->GetAttrPointer<float>(KC_SCALE_ATTR_INDEX);
} else {
mlaPrologContext.queryNormFlag = nullptr;
mlaPrologContext.weightQuantMode = nullptr;
mlaPrologContext.kvQuantMode = nullptr;
mlaPrologContext.queryQuantMode = nullptr;
mlaPrologContext.ckvkrRepoMode = nullptr;
mlaPrologContext.quantScaleRepoMode = nullptr;
mlaPrologContext.tileSize = nullptr;
mlaPrologContext.qcQrScale = nullptr;
mlaPrologContext.kcScale = nullptr;
}
OP_CHECK_IF(context.GetWorkspaceSizes(1) == nullptr,
OPS_REPORT_VECTOR_INNER_ERR(context.GetNodeName(), "workSpaceSize got from ge is nullptr"),
return ge::GRAPH_FAILED);
mlaPrologContext.workSpaces = context.GetWorkspaceSizes(1);
return ge::GRAPH_SUCCESS;
}
void MlaPrologTiling::ConvertRequiredParams(gert::TilingContext &context, MlaPrologContext &mlaPrologContext)
{
mlaPrologContext.tokenX.desc = context.GetRequiredInputDesc(TOKEN_X_INPUT_INDEX);
mlaPrologContext.tokenX.shape = context.GetRequiredInputShape(TOKEN_X_INPUT_INDEX);
mlaPrologContext.weightDq.desc = context.GetRequiredInputDesc(WEIGHT_DQ_INPUT_INDEX);
mlaPrologContext.weightDq.shape = context.GetRequiredInputShape(WEIGHT_DQ_INPUT_INDEX);
mlaPrologContext.weightUqQr.desc = context.GetRequiredInputDesc(WEIGHT_UQ_QR_INPUT_INDEX);
mlaPrologContext.weightUqQr.shape = context.GetRequiredInputShape(WEIGHT_UQ_QR_INPUT_INDEX);
mlaPrologContext.weightUk.desc = context.GetRequiredInputDesc(WEIGHT_UK_INPUT_INDEX);
mlaPrologContext.weightUk.shape = context.GetRequiredInputShape(WEIGHT_UK_INPUT_INDEX);
mlaPrologContext.weightDkvKr.desc = context.GetRequiredInputDesc(WEIGHT_DKV_KR_INPUT_INDEX);
mlaPrologContext.weightDkvKr.shape = context.GetRequiredInputShape(WEIGHT_DKV_KR_INPUT_INDEX);
mlaPrologContext.rmsnormGammaCq.desc = context.GetRequiredInputDesc(RMSNORM_GAMMA_CQ_INPUT_INDEX);
mlaPrologContext.rmsnormGammaCq.shape = context.GetRequiredInputShape(RMSNORM_GAMMA_CQ_INPUT_INDEX);
mlaPrologContext.rmsnormGammaCkv.desc = context.GetRequiredInputDesc(RMS_NORM_GAMMA_CKV_INPUT_INDEX);
mlaPrologContext.rmsnormGammaCkv.shape = context.GetRequiredInputShape(RMS_NORM_GAMMA_CKV_INPUT_INDEX);
mlaPrologContext.ropeSin.desc = context.GetRequiredInputDesc(ROPE_SIN_INPUT_INDEX);
mlaPrologContext.ropeSin.shape = context.GetRequiredInputShape(ROPE_SIN_INPUT_INDEX);
mlaPrologContext.ropeCos.desc = context.GetRequiredInputDesc(ROPE_COS_INPUT_INDEX);
mlaPrologContext.ropeCos.shape = context.GetRequiredInputShape(ROPE_COS_INPUT_INDEX);
if (std::strncmp(mlaPrologContext.opType, V3_OP_NAME, OP_NAME_LEN) == 0) {
mlaPrologContext.kvCache.desc = context.GetRequiredInputDesc(KV_CACHE_INPUT_INDEX_V3);
mlaPrologContext.kvCache.shape = context.GetRequiredInputShape(KV_CACHE_INPUT_INDEX_V3);
mlaPrologContext.krCache.desc = context.GetRequiredInputDesc(KR_CACHE_INPUT_INDEX_V3);
mlaPrologContext.krCache.shape = context.GetRequiredInputShape(KR_CACHE_INPUT_INDEX_V3);
} else {
mlaPrologContext.cacheIndex.desc = context.GetRequiredInputDesc(CACHE_INDEX_INPUT_INDEX);
mlaPrologContext.cacheIndex.shape = context.GetRequiredInputShape(CACHE_INDEX_INPUT_INDEX);
mlaPrologContext.kvCache.desc = context.GetRequiredInputDesc(KV_CACHE_INPUT_INDEX);
mlaPrologContext.kvCache.shape = context.GetRequiredInputShape(KV_CACHE_INPUT_INDEX);
mlaPrologContext.krCache.desc = context.GetRequiredInputDesc(KR_CACHE_INPUT_INDEX);
mlaPrologContext.krCache.shape = context.GetRequiredInputShape(KR_CACHE_INPUT_INDEX);
}
mlaPrologContext.query.desc = context.GetOutputDesc(QUERY_OUTPUT_INDEX);
mlaPrologContext.query.shape = context.GetOutputShape(QUERY_OUTPUT_INDEX);
mlaPrologContext.queryRope.desc = context.GetOutputDesc(QUERY_ROPE_OUTPUT_INDEX);
mlaPrologContext.queryRope.shape = context.GetOutputShape(QUERY_ROPE_OUTPUT_INDEX);
mlaPrologContext.kvCacheOut.desc = context.GetOutputDesc(KV_CACHE_OUT_OUTPUT_INDEX);
mlaPrologContext.kvCacheOut.shape = context.GetOutputShape(KV_CACHE_OUT_OUTPUT_INDEX);
mlaPrologContext.krCacheOut.desc = context.GetOutputDesc(KR_CACHE_OUT_OUTPUT_INDEX);
mlaPrologContext.krCacheOut.shape = context.GetOutputShape(KR_CACHE_OUT_OUTPUT_INDEX);
}
void MlaPrologTiling::ConvertOptionalParams(gert::TilingContext &context, MlaPrologContext &mlaPrologContext)
{
if (std::strncmp(mlaPrologContext.opType, V3_OP_NAME, OP_NAME_LEN) == 0) {
mlaPrologContext.cacheIndex.desc = context.GetRequiredInputDesc(CACHE_INDEX_INPUT_INDEX_V3);
mlaPrologContext.cacheIndex.shape = context.GetRequiredInputShape(CACHE_INDEX_INPUT_INDEX_V3);
}
mlaPrologContext.dequantScaleX.desc = context.GetOptionalInputDesc(DEQUANT_SCALE_X_INDEX);
mlaPrologContext.dequantScaleX.shape = context.GetOptionalInputShape(DEQUANT_SCALE_X_INDEX);
mlaPrologContext.dequantScaleWDq.desc = context.GetOptionalInputDesc(DEQUANT_SCALE_W_DQ_INDEX);
mlaPrologContext.dequantScaleWDq.shape = context.GetOptionalInputShape(DEQUANT_SCALE_W_DQ_INDEX);
mlaPrologContext.dequantScaleWUqQr.desc = context.GetOptionalInputDesc(DEQUANT_SCALE_W_UQ_QR_INDEX);
mlaPrologContext.dequantScaleWUqQr.shape = context.GetOptionalInputShape(DEQUANT_SCALE_W_UQ_QR_INDEX);
mlaPrologContext.dequantScaleWDkvKr.desc = context.GetOptionalInputDesc(DEQUANT_SCALE_W_DKV_KR_INDEX);
mlaPrologContext.dequantScaleWDkvKr.shape = context.GetOptionalInputShape(DEQUANT_SCALE_W_DKV_KR_INDEX);
mlaPrologContext.quantScaleCkv.desc = context.GetOptionalInputDesc(QUANT_SCALE_CKV_INDEX);
mlaPrologContext.quantScaleCkv.shape = context.GetOptionalInputShape(QUANT_SCALE_CKV_INDEX);
mlaPrologContext.quantScaleCkr.desc = context.GetOptionalInputDesc(QUANT_SCALE_CKR_INDEX);
mlaPrologContext.quantScaleCkr.shape = context.GetOptionalInputShape(QUANT_SCALE_CKR_INDEX);
mlaPrologContext.smoothScalesCq.desc = context.GetOptionalInputDesc(SMOOTH_SCALES_CQ_INDEX);
mlaPrologContext.smoothScalesCq.shape = context.GetOptionalInputShape(SMOOTH_SCALES_CQ_INDEX);
mlaPrologContext.actualSeqLen.desc = context.GetOptionalInputDesc(ACTUAL_SEQ_LEN_INDEX);
mlaPrologContext.actualSeqLen.shape = context.GetOptionalInputShape(ACTUAL_SEQ_LEN_INDEX);
mlaPrologContext.kNopeClipAlpha.desc = context.GetOptionalInputDesc(K_NOPE_CLIP_ALPHA_INDEX);
mlaPrologContext.kNopeClipAlpha.shape = context.GetOptionalInputShape(K_NOPE_CLIP_ALPHA_INDEX);
if (std::strncmp(mlaPrologContext.opType, V1_OP_NAME, OP_NAME_LEN) == 0) {
mlaPrologContext.dequantScaleQNope.desc = nullptr;
mlaPrologContext.dequantScaleQNope.shape = nullptr;
} else {
mlaPrologContext.dequantScaleQNope.desc = context.GetOutputDesc(DEQUANT_SCALE_Q_NOPE_OUTPUT_INDEX);
mlaPrologContext.dequantScaleQNope.shape = context.GetOutputShape(DEQUANT_SCALE_Q_NOPE_OUTPUT_INDEX);
}
if (std::strncmp(mlaPrologContext.opType, V3_OP_NAME, OP_NAME_LEN) == 0) {
mlaPrologContext.queryNorm.desc = context.GetOutputDesc(QUERY_NORM_OUTPUT_INDEX);
mlaPrologContext.queryNorm.shape = context.GetOutputShape(QUERY_NORM_OUTPUT_INDEX);
mlaPrologContext.dequantScaleQNorm.desc = context.GetOutputDesc(DEQUANT_SCALE_Q_NORM_OUTPUT_INDEX);
mlaPrologContext.dequantScaleQNorm.shape = context.GetOutputShape(DEQUANT_SCALE_Q_NORM_OUTPUT_INDEX);
} else {
mlaPrologContext.queryNorm.desc = nullptr;
mlaPrologContext.queryNorm.shape = nullptr;
mlaPrologContext.dequantScaleQNorm.desc = nullptr;
mlaPrologContext.dequantScaleQNorm.shape = nullptr;
}
}
MLA_EXTERN_C ge::graphStatus TilingMlaProlog(gert::TilingContext *context)
{
OP_CHECK_IF(context == nullptr,
OPS_REPORT_VECTOR_INNER_ERR(V1_OP_NAME, "Context is nullptr."), return ge::GRAPH_FAILED);
MlaPrologContext mlaPrologContext{};
OP_CHECK_IF(MlaPrologTiling::ConvertContext(*context, mlaPrologContext) != ge::GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "Error occurred while converting tilingContext to MlaProlog context."),
return ge::GRAPH_FAILED);
MlaPrologTiling mlaPrologTiling;
MlaPrologTilingData* tilingData = context->GetTilingData<MlaPrologTilingData>();
OP_CHECK_IF(tilingData == nullptr,
OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "TilingData is nullptr."),
return ge::GRAPH_FAILED);
if (mlaPrologTiling.RunBigKernelTiling(mlaPrologContext, tilingData) == ge::SUCCESS) {
context->SetTilingKey(mlaPrologContext.tilingKey);
context->SetBlockDim(mlaPrologContext.blockDim);
return ge::GRAPH_SUCCESS;
}
return ge::GRAPH_FAILED;
}
}