/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
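/*!
 * \file mla_prolog_tiling.h
 * \brief Tiling declarations (index constants, mode enums, context structs and the MlaPrologTiling class)
 *        for the MlaProlog, MlaPrologV2 and MlaPrologV3 operators.
 */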
#ifndef MLA_PROLOG_TILING_H
#define MLA_PROLOG_TILING_H
#include <cstdint>
#include <string>
#include <vector>
#include <map>
#include <unordered_map>
#include <set>
#include <sstream>
#include <algorithm>
#include <unordered_set>
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "op_host/data_copy_transpose_tiling.h"
#include "exe_graph/runtime/tiling_context.h"
#include "register/op_def_registry.h"
#include "../op_kernel/mla_prolog_template_tiling_key.h"
#include "../op_kernel/mla_prolog_tiling_data.h"
#include "platform/soc_spec.h"
// MLA_EXTERN_C expands to `extern "C"` when ASCENDC_OP_TEST is defined and to nothing otherwise.
// It prefixes the TilingMlaProlog declaration at the end of this header, so that function is
// declared with C linkage only in ASCENDC_OP_TEST builds.
#ifdef ASCENDC_OP_TEST
#define MLA_EXTERN_C extern "C"
#else
#define MLA_EXTERN_C
#endif
namespace optiling {
// ---- Operator input positions ----------------------------------------------------------------
// Indices 0..8 are shared by every layout declared here.
constexpr uint32_t TOKEN_X_INPUT_INDEX{0U};
constexpr uint32_t WEIGHT_DQ_INPUT_INDEX{1U};
constexpr uint32_t WEIGHT_UQ_QR_INPUT_INDEX{2U};
constexpr uint32_t WEIGHT_UK_INPUT_INDEX{3U};
constexpr uint32_t WEIGHT_DKV_KR_INPUT_INDEX{4U};
constexpr uint32_t RMSNORM_GAMMA_CQ_INPUT_INDEX{5U};
constexpr uint32_t RMS_NORM_GAMMA_CKV_INPUT_INDEX{6U};
constexpr uint32_t ROPE_SIN_INPUT_INDEX{7U};
constexpr uint32_t ROPE_COS_INPUT_INDEX{8U};

// Default ordering of slots 9..11: cache index first, then the two caches.
constexpr uint32_t CACHE_INDEX_INPUT_INDEX{9U};
constexpr uint32_t KV_CACHE_INPUT_INDEX{10U};
constexpr uint32_t KR_CACHE_INPUT_INDEX{11U};

// "_V3" ordering of the same three slots: the two caches first, cache index last.
// NOTE(review): presumably selected when the op type is MlaPrologV3 -- confirm in the .cpp.
constexpr uint32_t KV_CACHE_INPUT_INDEX_V3{9U};
constexpr uint32_t KR_CACHE_INPUT_INDEX_V3{10U};
constexpr uint32_t CACHE_INDEX_INPUT_INDEX_V3{11U};

// Scale / auxiliary inputs occupying slots 12..20.
constexpr uint32_t DEQUANT_SCALE_X_INDEX{12U};
constexpr uint32_t DEQUANT_SCALE_W_DQ_INDEX{13U};
constexpr uint32_t DEQUANT_SCALE_W_UQ_QR_INDEX{14U};
constexpr uint32_t DEQUANT_SCALE_W_DKV_KR_INDEX{15U};
constexpr uint32_t QUANT_SCALE_CKV_INDEX{16U};
constexpr uint32_t QUANT_SCALE_CKR_INDEX{17U};
constexpr uint32_t SMOOTH_SCALES_CQ_INDEX{18U};
constexpr uint32_t ACTUAL_SEQ_LEN_INDEX{19U};
constexpr uint32_t K_NOPE_CLIP_ALPHA_INDEX{20U};
// ---- Operator output positions ---------------------------------------------------------------
constexpr uint32_t QUERY_OUTPUT_INDEX{0U};
constexpr uint32_t QUERY_ROPE_OUTPUT_INDEX{1U};
constexpr uint32_t KV_CACHE_OUT_OUTPUT_INDEX{2U};
constexpr uint32_t KR_CACHE_OUT_OUTPUT_INDEX{3U};
constexpr uint32_t DEQUANT_SCALE_Q_NOPE_OUTPUT_INDEX{4U};
constexpr uint32_t QUERY_NORM_OUTPUT_INDEX{5U};
constexpr uint32_t DEQUANT_SCALE_Q_NORM_OUTPUT_INDEX{6U};
// ---- Operator attribute positions ------------------------------------------------------------
constexpr uint32_t RMS_NORM_EPSILON_CQ_ATTR_INDEX{0U};
constexpr uint32_t RMS_NORM_EPSILON_CKV_ATTR_INDEX{1U};
constexpr uint32_t CACHE_MODE_ATTR_INDEX{2U};
constexpr uint32_t QUERY_NORM_FLAG_ATTR_INDEX{3U};
constexpr uint32_t WEIGHT_QUANT_MODE_ATTR_INDEX{4U};
constexpr uint32_t KV_CACHE_QUANT_MODE_ATTR_INDEX{5U};
constexpr uint32_t QUERY_QUANT_MODE_ATTR_INDEX{6U};
constexpr uint32_t CKVKR_REPO_MODE_ATTR_INDEX{7U};
constexpr uint32_t QUANT_SCALE_REPO_MODE_ATTR_INDEX{8U};
constexpr uint32_t TILE_SIZE_ATTR_INDEX{9U};
constexpr uint32_t QC_QR_SCALE_ATTR_INDEX{10U};
constexpr uint32_t KC_SCALE_ATTR_INDEX{11U};
// ---- Named dimension positions (DIM_INDEX_k == k) ---------------------------------------------
constexpr uint32_t MLA_PROLOG_DIM_INDEX_0{0U};
constexpr uint32_t MLA_PROLOG_DIM_INDEX_1{1U};
constexpr uint32_t MLA_PROLOG_DIM_INDEX_2{2U};
constexpr uint32_t MLA_PROLOG_DIM_INDEX_3{3U};

// ---- Named dimension counts (DIM_NUM_k == k) --------------------------------------------------
constexpr uint32_t MLA_PROLOG_DIM_NUM_0{0U};
constexpr uint32_t MLA_PROLOG_DIM_NUM_1{1U};
constexpr uint32_t MLA_PROLOG_DIM_NUM_2{2U};
constexpr uint32_t MLA_PROLOG_DIM_NUM_3{3U};
constexpr uint32_t MLA_PROLOG_DIM_NUM_4{4U};
// Textual spellings of the supported cache modes (NUL-terminated char arrays).
constexpr char CACHE_MODE_BSND[] = "BSND";
constexpr char CACHE_MODE_TND[] = "TND";
constexpr char CACHE_MODE_PA_BSND[] = "PA_BSND";
constexpr char CACHE_MODE_PA_NZ[] = "PA_NZ";
constexpr char CACHE_MODE_PA_BLK_BSND[] = "PA_BLK_BSND";
constexpr char CACHE_MODE_PA_BLK_NZ[] = "PA_BLK_NZ";

// Operator type names for the three operator versions.
constexpr char V1_OP_NAME[] = "MlaProlog";
constexpr char V2_OP_NAME[] = "MlaPrologV2";
constexpr char V3_OP_NAME[] = "MlaPrologV3";

// Size in bytes of the longest cache-mode string above, including its terminating NUL.
constexpr uint32_t CACHE_MODE_LEN = static_cast<uint32_t>(std::max({
    sizeof(CACHE_MODE_BSND),
    sizeof(CACHE_MODE_TND),
    sizeof(CACHE_MODE_PA_BSND),
    sizeof(CACHE_MODE_PA_NZ),
    sizeof(CACHE_MODE_PA_BLK_BSND),
    sizeof(CACHE_MODE_PA_BLK_NZ),
}));

// Size in bytes of the longest operator name above, including its terminating NUL.
constexpr uint32_t OP_NAME_LEN = static_cast<uint32_t>(std::max({
    sizeof(V1_OP_NAME),
    sizeof(V2_OP_NAME),
    sizeof(V3_OP_NAME),
}));
// Plain aggregate of uint32_t size fields; every field defaults to 0.
// Held by MlaPrologTiling as baseShapeInfo_.
// NOTE(review): the meaning of each field is not visible in this header (names suggest batch,
// sequence, head and block dimensions) -- confirm against MlaPrologTiling::SetShapeInfo.
struct MlaPrologBaseShapeInfo {
    uint32_t bSize{0U};
    uint32_t s1Size{0U};
    uint32_t tSize{0U};
    uint32_t s2Size{0U};
    uint32_t heSize{0U};
    uint32_t hcqSize{0U};
    uint32_t hckvSize{0U};
    uint32_t headSizeQc{0U};
    uint32_t headSizeQr{0U};
    uint32_t headSizeUqQr{0U};
    uint32_t nSize{0U};
    uint32_t nkvSize{0U};
    uint32_t dSize{0U};
    uint32_t blockNum{0U};
    uint32_t blockSize{0U};
    uint32_t drSize{0U};
    uint32_t dtileSize{0U};
};
// Numeric cache-mode identifiers (one byte each).
// Enumerator names match the CACHE_MODE_* string constants declared in this header;
// note that TND carries the highest value (5).
enum class CACHE_MODE : uint8_t {
    BSND        = 0,
    PA_BSND     = 1,
    PA_NZ       = 2,
    PA_BLK_BSND = 3,
    PA_BLK_NZ   = 4,
    TND         = 5,
};
// Empty-tensor scenario selector (one byte); stored in MlaPrologScenarioInfo::emptyTensorMode_.
enum class EMPTY_TENSOR_MODE : uint8_t {
    NON_EMPTY   = 0,
    EMPTY_CACHE = 1,
    EMPTY_QUERY = 2,
};
// Actual-sequence-length scenario selector (one byte); stored in MlaPrologScenarioInfo::actualSeqMode_.
enum class ACTUAL_SEQ_MODE : uint8_t {
    DISABLED = 0,
    EN_Q_LEN = 1
};
// Combined quantization scenario. The underlying type is signed (int8_t) so that ERROR_MODE
// can be represented as -1. This is the return type of MlaPrologTiling::GetQuantizationMode*().
// NOTE(review): enumerator names appear to pair a weight-quantization flavour with a KV-cache
// quantization granularity (cf. WEIGHT_QUANT_MODE / KV_QUANT_MODE) -- confirm in the .cpp.
enum class QUANT_MODE : int8_t {
    ERROR_MODE                           = -1,

    NO_QUANT                             = 0,
    PARTIAL_QUANT_KV_NO_QUANT            = 1,
    PARTIAL_QUANT_KV_QUANT_PER_CHANNEL   = 2,
    FULL_QUANT_KV_NO_QUANT               = 3,
    FULL_QUANT_KV_QUANT_PER_TENSOR       = 4,
    PARTIAL_QUANT_KV_QUANT_PER_TILE      = 5,
    FULL_QUANT_KV_QUANT_PER_TILE         = 6,

    MXFP8_FULL_QUANT_KV_NO_QUANT         = 7,
    MXFP8_FULL_QUANT_KV_QUANT_PER_TENSOR = 8,
    MXFP8_FULL_QUANT_KV_QUANT_PER_TILE   = 9,

    FP8_FULL_QUANT_KV_NO_QUANT           = 10,
    FP8_FULL_QUANT_KV_QUANT_PER_TENSOR   = 11,
    HIF8_FULL_QUANT_KV_NO_QUANT          = 12,
    HIF8_FULL_QUANT_KV_QUANT_PER_TENSOR  = 13,
    FP8_FULL_QUANT_KV_QUANT_PER_TILE     = 14,
    HIF8_FULL_QUANT_KV_QUANT_PER_TILE    = 15,
};
// Weight quantization selector (one byte); type of MlaPrologTiling::weightQuantMode_.
enum class WEIGHT_QUANT_MODE : uint8_t {
    NO_QUANT         = 0,
    PARTIAL_QUANT    = 1,
    FULL_QUANT       = 2,
    MXFP8_FULL_QUANT = 3,
    FP8_FULL_QUANT   = 4,
    HIF8_FULL_QUANT  = 5,
};
// KV-cache quantization selector (one byte); type of MlaPrologTiling::kvQuantMode_.
enum class KV_QUANT_MODE : uint8_t {
    NO_QUANT    = 0,
    PER_TENSOR  = 1,
    PER_CHANNEL = 2,
    PER_TILE    = 3,
};
// Query quantization selector (one byte); type of MlaPrologTiling::queryQuantMode_.
enum class QUERY_QUANT_MODE : uint8_t {
    NO_QUANT       = 0,
    PER_TOKEN_HEAD = 1,
};
// Two-state selector (one byte); type of MlaPrologTiling::ckvkrRepoMode_.
// NOTE(review): presumably chooses between separate and combined storage of ckv/kr -- confirm.
enum class CKVKR_REPO_MODE : uint8_t {
    DIVIDE  = 0,
    COMBINE = 1,
};
// Two-state selector (one byte); type of MlaPrologTiling::quantSacleRepoMode_.
// NOTE(review): presumably chooses between separate and combined storage of quant scales -- confirm.
enum class QUANT_SCALE_REPO_MODE : uint8_t {
    DIVIDE  = 0,
    COMBINE = 1,
};
// Scenario flags and modes held by MlaPrologTiling as scenarioInfo_.
//
// Every member carries a default initializer whose value equals zero-initialization.
// MlaPrologTiling has a defaulted constructor, so without these defaults a default-constructed
// tiling object would hold indeterminate flags until SetScenarioInfo() ran. The struct remains
// an aggregate (C++14 and later), so existing brace initialization keeps working.
struct MlaPrologScenarioInfo {
    bool isV1Flag_ = false;
    bool batchSeqFusedFlag_ = false;
    uint8_t splitMFlag_ = 0;
    QUANT_MODE quantMode_ = QUANT_MODE::NO_QUANT;
    CACHE_MODE cacheMode_ = CACHE_MODE::BSND;
    EMPTY_TENSOR_MODE emptyTensorMode_ = EMPTY_TENSOR_MODE::NON_EMPTY;
    ACTUAL_SEQ_MODE actualSeqMode_ = ACTUAL_SEQ_MODE::DISABLED;
};
// Compile-time information record for the MlaProlog tiling.
// core_num defaults to 0 so a default-constructed record never exposes an indeterminate value.
// NOTE(review): presumably populated by TilingPrepareForMlaProlog -- confirm in the .cpp.
struct MlaPrologCompileInfo {
    int64_t core_num = 0;
};
// Non-owning (descriptor, shape) pointer pair describing one operator parameter.
// Both pointers default to nullptr so that a parameter which was never filled in can be
// detected with a null check instead of reading an indeterminate pointer.
struct BaseParaInfo {
    const gert::CompileTimeTensorDesc *desc = nullptr;
    const gert::StorageShape *shape = nullptr;
};

// Tag type for a parameter declared as required in MlaPrologContext; adds no members.
struct RequiredParaInfo : BaseParaInfo {
};

// Tag type for a parameter declared as optional in MlaPrologContext; adds no members.
struct OptionalParaInfo : BaseParaInfo {
};
constexpr uint32_t BLOCK_SIZE = 32U;

// Element sizes in bytes for the named data types.
constexpr uint32_t NUM_BYTES_BF16 = 2U;
constexpr uint32_t NUM_BYTES_INT8 = 1U;
constexpr uint32_t NUM_BYTES_INT32 = 4U;
constexpr uint32_t NUM_BYTES_FP32 = 4U;

constexpr uint32_t GROUP_COMPUTE_CUBE_NUM_PER_GROUP = 8U;

// Conforming spelling of the high-throughput D size.
constexpr uint32_t HIGH_THROUGHPUT_D_SIZE = 128U;
// Legacy spelling kept so existing users still compile. An identifier containing a double
// underscore is reserved to the implementation in C++, so new code should use
// HIGH_THROUGHPUT_D_SIZE instead.
constexpr uint32_t HIGH_THROUGHPUT__D_SIZE = HIGH_THROUGHPUT_D_SIZE;

// NOTE(review): the GROUP_COMPUTE_* values look like thresholds for the group-compute
// optimisation (cf. MlaPrologTiling::enableGroupComputeOpt_) -- confirm in the .cpp.
constexpr uint32_t GROUP_COMPUTE_T_SIZE = 1U;
constexpr uint32_t GROUP_COMPUTE_NKV_SIZE = 8U;
constexpr uint32_t GROUP_COMPUTE_MIN_AIC_NUM = 16U;
constexpr uint32_t GROUP_COMPUTE_MIN_AIV_NUM = 32U;
constexpr uint32_t GROUP_COMPUTE_N_SIZE = 8U;
// Flattened view of one operator invocation, passed to MlaPrologTiling::ConvertContext and
// MlaPrologTiling::RunBigKernelTiling. All pointers are non-owning.
//
// Member names parallel the *_INPUT_INDEX, *_OUTPUT_INDEX and *_ATTR_INDEX constants declared
// in this header.
//
// Every member has a default initializer (null pointer / zero), so a default-constructed
// context is in a well-defined "nothing filled in yet" state and absent optional parameters or
// attributes can be detected with a null check.
struct MlaPrologContext {
    const char *opName = nullptr;
    const char *opType = nullptr;
    fe::PlatFormInfos *platformInfo = nullptr;

    // ---- Required inputs ----
    RequiredParaInfo tokenX{};
    RequiredParaInfo weightDq{};
    RequiredParaInfo weightUqQr{};
    RequiredParaInfo weightUk{};
    RequiredParaInfo weightDkvKr{};
    RequiredParaInfo rmsnormGammaCq{};
    RequiredParaInfo rmsnormGammaCkv{};
    RequiredParaInfo ropeSin{};
    RequiredParaInfo ropeCos{};
    RequiredParaInfo kvCache{};
    RequiredParaInfo krCache{};

    // ---- Optional inputs ----
    OptionalParaInfo cacheIndex{};
    OptionalParaInfo dequantScaleX{};
    OptionalParaInfo dequantScaleWDq{};
    OptionalParaInfo dequantScaleWUqQr{};
    OptionalParaInfo dequantScaleWDkvKr{};
    OptionalParaInfo quantScaleCkv{};
    OptionalParaInfo quantScaleCkr{};
    OptionalParaInfo smoothScalesCq{};
    OptionalParaInfo actualSeqLen{};
    OptionalParaInfo kNopeClipAlpha{};

    // ---- Outputs ----
    RequiredParaInfo query{};
    RequiredParaInfo queryRope{};
    RequiredParaInfo kvCacheOut{};
    RequiredParaInfo krCacheOut{};
    OptionalParaInfo dequantScaleQNope{};
    OptionalParaInfo queryNorm{};
    OptionalParaInfo dequantScaleQNorm{};

    // ---- Attributes (pointer stays null until the attribute has been filled in) ----
    // "Espilon" is a historical misspelling of "Epsilon"; the names are kept for compatibility.
    const float *rmsNormEspilonCq = nullptr;
    const float *rmsNormEspilonCkv = nullptr;
    const char *cacheMode = nullptr;
    const bool *queryNormFlag = nullptr;
    const int64_t *weightQuantMode = nullptr;
    const int64_t *kvQuantMode = nullptr;
    const int64_t *queryQuantMode = nullptr;
    const int64_t *ckvkrRepoMode = nullptr;
    const int64_t *quantScaleRepoMode = nullptr;
    const int64_t *tileSize = nullptr;
    const float *qcQrScale = nullptr;
    const float *kcScale = nullptr;

    // ---- Tiling results ----
    size_t *workSpaces = nullptr;
    uint64_t tilingKey = 0;
    uint32_t blockDim = 0;
};
// Tiling driver for the MlaProlog operator family.
//
// Public surface: ConvertContext() builds an MlaPrologContext from a gert::TilingContext, and
// RunBigKernelTiling() takes that context together with the MlaPrologTilingData to fill.
// Everything else is a private step or private state; implementations live outside this header.
class MlaPrologTiling {
public:
    MlaPrologTiling() = default;
    ~MlaPrologTiling() = default;

    // Tiling entry point; returns a ge::graphStatus.
    ge::graphStatus RunBigKernelTiling(MlaPrologContext &context, MlaPrologTilingData *tilingData);

    // Static adapter from the framework tiling context to MlaPrologContext.
    static ge::graphStatus ConvertContext(gert::TilingContext &context, MlaPrologContext &mlaPrologContext);

private:
    // ---- Context conversion helpers ----
    static void ConvertRequiredParams(gert::TilingContext &context, MlaPrologContext &mlaPrologContext);
    static void ConvertOptionalParams(gert::TilingContext &context, MlaPrologContext &mlaPrologContext);

    // ---- Information gathering ----
    ge::graphStatus GetNpuInfo();
    ge::graphStatus SetScenarioInfo();
    ge::graphStatus SetAttrInfo();
    QUANT_MODE GetQuantizationMode() const;
    QUANT_MODE GetQuantizationModeV3() const;
    QUANT_MODE GetQuantizationModeV3Mxfp8() const;
    ge::graphStatus SetShapeInfo();
    ge::graphStatus ProcessBaseInputs();

    // ---- Tiling data population ----
    ge::graphStatus FillTiling();
    ge::graphStatus FillTilingCoreParams();
    ge::graphStatus FillMatmul1Tiling();
    ge::graphStatus FillMatmul2Tiling();
    ge::graphStatus FillMatmul3Tiling();
    ge::graphStatus FillMatmul4Tiling();
    uint32_t CalcSingleCoreN(uint32_t n, uint32_t coreNum, uint32_t alignNum = 16) const;
    bool GetMatmulType(ge::DataType getype, matmul_tiling::DataType *mmType);

    // ---- Workspace / tiling key / architecture ----
    ge::graphStatus CalcWorkSpace();
    ge::graphStatus GenTilingKey() const;
    NpuArch GetCurNpuArch() const;

    // ---- Shape and scenario state ----
    MlaPrologBaseShapeInfo baseShapeInfo_;
    MlaPrologScenarioInfo scenarioInfo_;

    // ---- Split parameters ----
    uint32_t stepBatchSize_{0};
    uint32_t stepNumHeadDequant_{0};
    uint32_t mSubSize_{0};
    uint32_t mSubCoreNum_{0};
    uint32_t mm1BlockNum_{0};
    uint32_t mm2BlockNum_{0};
    uint32_t mm3BlockNum_{0};
    uint32_t mm4BlockNum_{0};
    uint32_t vectorBlockNum_{0};
    uint32_t singlecoreHeadSizeCq_{0};
    uint32_t singlecoreHeadSizeQcQr_{0};
    uint32_t singlecoreHeadSizeCkvKr_{0};
    uint32_t singlecoreNumHeadSize_{0};

    // ---- Attribute-derived state ----
    // NOTE(review): the defaults below look swapped (1e-5 is a typical epsilon, yet it is the
    // default of reciprocal*_, while epsilon*_ defaults to 1.0). Values are kept exactly as they
    // were; confirm that SetAttrInfo() always overwrites them before use.
    float reciprocalCq_{0.00001f};
    float epsilonCq_{1.0f};
    float reciprocalCkv_{0.00001f};
    float epsilonCkv_{1.0f};
    bool queryNormFlag_{false};
    WEIGHT_QUANT_MODE weightQuantMode_{WEIGHT_QUANT_MODE::NO_QUANT};
    KV_QUANT_MODE kvQuantMode_{KV_QUANT_MODE::NO_QUANT};
    QUERY_QUANT_MODE queryQuantMode_{QUERY_QUANT_MODE::NO_QUANT};
    CKVKR_REPO_MODE ckvkrRepoMode_{CKVKR_REPO_MODE::DIVIDE};
    // "Sacle" is a misspelling of "Scale"; it cannot be renamed here without touching the .cpp.
    QUANT_SCALE_REPO_MODE quantSacleRepoMode_{QUANT_SCALE_REPO_MODE::DIVIDE};
    uint32_t tileSize_{128};
    float qcQrScale_{1.0f};
    float kcScale_{1.0f};
    // "Date" is presumably a misspelling of "Data"; kept for the same reason.
    ge::DataType mmDateType_{ge::DT_BF16};

    // ---- Optimisation switches ----
    bool enableDequantOpt_{false};
    bool enableGroupComputeOpt_{false};

    // ---- Platform resources ----
    size_t ubSize_{0};
    size_t l1Size_{0};
    size_t l0cSize_{0};
    size_t l0bSize_{0};
    uint32_t coreNum_{0};
    uint32_t aicNum_{0};
    uint32_t aivNum_{0};

    // ---- Workspace accounting ----
    size_t libapiSize_{0};
    size_t workspaceSize_{0};

    // ---- Non-owning pointers; null by default ----
    MlaPrologContext *context_{nullptr};
    MlaPrologBaseParams *baseParams_{nullptr};
};
// Tiling-parse entry point: takes the framework's TilingParseContext and returns a ge::graphStatus.
// NOTE(review): presumably fills MlaPrologCompileInfo -- confirm in the .cpp.
ge::graphStatus TilingPrepareForMlaProlog(gert::TilingParseContext *context);
// Tiling entry point: takes the framework's TilingContext and returns a ge::graphStatus.
// Declared with MLA_EXTERN_C, so it has C linkage when ASCENDC_OP_TEST is defined.
MLA_EXTERN_C ge::graphStatus TilingMlaProlog(gert::TilingContext *context);
}
#endif