/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
* \file matmul_tiling_algorithm.h
* \brief Status/parameter structures and the MatmulTilingAlgorithm class declaration.
*/
#ifndef IMPL_MATMUL_TILING_MATMUL_TILING_ALGORITHM_H
#define IMPL_MATMUL_TILING_MATMUL_TILING_ALGORITHM_H
#include <cstdint>
#include <string>
#include <algorithm>
#include <climits>
#include <cmath>
#include <array>
#include <vector>
#include "include/adv_api/matmul/matmul_tiling.h"
#include "include/adv_api/matmul/matmul_tiling_base.h"
namespace matmul_tiling {
// Values used as defaults for the *DB fields of MatmulTemplateCfg and the dbL0*_ members of
// MatmulTilingAlgorithm. NOTE(review): "DB" presumably means double buffering (2 = on, 1 = off) - confirm.
constexpr int32_t DB_ON = 2;
constexpr int32_t DB_OFF = 1;

// Number of entries in the factor array accepted by L1StatusPack::SetStatus.
constexpr int32_t L1_FACTORS_LEN = 6;

// Number of L0Factors candidates passed to MatmulTilingAlgorithm::GetL0cDB.
constexpr int32_t L0PARAS_COMBO_LEN = 2;

// Absolute tolerance applied to avgIntensity in ComputeIntensity::operator<.
constexpr float MATMUL_EPSILON = 0.0001F;

// Relative tolerance applied to memoryTraffic in ComputeIntensitySmallShape::operator<.
constexpr float MEMORY_EPSILON = 0.05F;
// Plain configuration record: four buffer-mode fields plus two split switches.
// NOTE(review): the meaning of each field is defined by its consumers, which are not visible here.
struct MatmulTemplateCfg {
    // Buffer-mode fields; each holds DB_ON or DB_OFF.
    int32_t l0aDB = DB_ON;
    int32_t l0bDB = DB_ON;
    int32_t l0cDB = DB_OFF;
    int32_t l1DB = DB_ON;

    // Split switches.
    bool factorSplit = true;
    bool kSplit = false;
};
// Mutable record of L1-level tiling state. All fields have in-class defaults, so a
// default-constructed object is fully initialised.
struct L1StatusPack {
    // Factors written by SetStatus (kAL1, kBL1, mAL1, nBL1, dbAL1, dbBL1) plus mL1 / nL1.
    int32_t kAL1 = 1;
    int32_t kBL1 = 1;
    int32_t mL1 = 1;
    int32_t nL1 = 1;
    int32_t mAL1 = 1;
    int32_t nBL1 = 1;
    int32_t dbAL1 = 1;
    int32_t dbBL1 = 1;

    // Size / repetition bookkeeping.
    int32_t aL1Size = 0;
    int32_t bL1Size = 0;
    int32_t aL1Times = 1;
    int32_t bL1Times = 1;
    int32_t allTimes = 1;
    int32_t loadSize = 0;

    // Upper bounds for the corresponding factors.
    int32_t maxMAL1 = 1;
    int32_t maxNBL1 = 1;
    int32_t maxKAL1 = 1;
    int32_t maxKBL1 = 1;

    // Full-load flags.
    bool bothFullLoad = false;
    bool aL1FullLoad = false;
    bool bL1FullLoad = false;
    bool aL1KFullLoad = false;
    bool bL1KFullLoad = false;

    int32_t channelWiseTimes = 0;

    // Copies the six entries of tmpL1Factors into the factor fields.
    // Index layout: [0]=kAL1, [1]=kBL1, [2]=mAL1, [3]=nBL1, [4]=dbAL1, [5]=dbBL1.
    // No other field is modified.
    void SetStatus(const int32_t (&tmpL1Factors)[L1_FACTORS_LEN])
    {
        kAL1 = tmpL1Factors[0];
        kBL1 = tmpL1Factors[1];
        mAL1 = tmpL1Factors[2];
        nBL1 = tmpL1Factors[3];
        dbAL1 = tmpL1Factors[4];
        dbBL1 = tmpL1Factors[5];
    }
};
// Mutable record of L0-level tiling state. The "final*" fields together with updateUsingMte1
// form the subset that InitLoadStatus() resets.
struct L0StatusPack {
    int32_t mL0 = 1;
    int32_t nL0 = 1;
    int32_t kL0 = 1;
    int32_t batchL0 = 1;
    int32_t l0cMultiBatch = 0;
    int32_t dbL0A = 1;
    int32_t dbL0B = 1;
    int32_t dbL0C = 1;
    int32_t dbCub = 1;

    // Fields reset by InitLoadStatus() (see also updateUsingMte1 below).
    int32_t finalML0 = 0;
    int32_t finalKL0 = 0;
    int32_t finalNL0 = 0;
    int32_t finalLoadSize = INT_MAX;
    float finalL0cUse = 0;
    int32_t finalMul = 0;
    int32_t finalMte1Loop = INT_MAX;
    int32_t finalMte1Cycles = 0;

    int32_t maxMk = 1;
    int32_t maxNk = 1;
    int32_t maxMn = 1;
    int32_t maxAxisIdx = 0;
    int32_t maxAxisNum = 0;
    int32_t maxAxisPnt = 1;
    int32_t maxN = 1;
    int32_t dtypeBias = 0;
    int32_t load2dTimes = 0;
    int32_t l0cUsed = 0;
    bool updateUsingMte1 = false;

    // Restores the "final*" fields and updateUsingMte1 to their default values.
    // Every other field keeps its current value.
    void InitLoadStatus()
    {
        finalML0 = finalKL0 = finalNL0 = 0;
        finalLoadSize = finalMte1Loop = INT_MAX;
        finalL0cUse = 0;
        finalMul = finalMte1Cycles = 0;
        updateUsingMte1 = false;
    }
};
// Per-core tiling status record.
struct CoreStatusPack {
    // The constructor overrides the in-class defaults of cycle and loadSize with INT_MAX;
    // the other members it names receive the same value (1) as their in-class defaults.
    // Initialisers are listed in member-declaration order.
    CoreStatusPack()
        : batch(1), batchDim(1), mDim(1), nDim(1), kDim(1), cycle(INT_MAX), loadSize(INT_MAX)
    {}

    int32_t batch = 1;
    int32_t m = 1;
    int32_t k = 1;
    int32_t n = 1;

    // Split factors per axis.
    int32_t batchDim = 1;
    int32_t mDim = 1;
    int32_t nDim = 1;
    int32_t kDim = 1;

    int32_t kAl1Factor = 1;
    int32_t kBl1Factor = 1;
    int32_t mSingleCore = 1;
    int32_t nSingleCore = 1;
    int32_t n0Max = 1;

    int64_t cycle = 1;     // becomes INT_MAX after construction
    int32_t loadSize = 1;  // becomes INT_MAX after construction
    int32_t aL1FullLoadSize = 0;
    int32_t bL1FullLoadSize = 0;
    int64_t madCycle = 1;
    int64_t repeatLoadSize = INT_MAX;
};
// Bundles the L0 and L1 status records so they can be passed around as one object.
struct SingleCoreStatus {
    L0StatusPack l0Status;  // L0-level state
    L1StatusPack l1Status;  // L1-level state
};
// Candidate record with the same field names, types and defaults as the "final*" fields
// of L0StatusPack.
struct L0Factors {
    int32_t finalML0 = 0;
    int32_t finalKL0 = 0;
    int32_t finalNL0 = 0;
    int32_t finalLoadSize = INT_MAX;  // starts at the maximum int value
    float finalL0cUse = 0;
    int32_t finalMul = 0;
    int32_t finalMte1Loop = INT_MAX;  // starts at the maximum int value
    int32_t finalMte1Cycles = 0;
};
// Fixed-size bag of ten int32_t parameters.
// The array has no default initialiser: a default-initialised object holds indeterminate
// values, so callers must fill it (or value-initialise with {}) before reading.
struct MKNParasCombo {
    int32_t parasCombo[10];
};
// Run-time description of one matmul problem: original 64-bit shapes, dtype codes,
// layout flags, 32-bit shape copies and "mapped" shapes.
struct MatmulRunParas {
    // Original shapes (64-bit).
    int64_t oriShapeM = 1;
    int64_t oriShapeN = 1;
    int64_t oriShapeKa = 1;
    int64_t oriShapeKb = 1;
    int64_t oriShapeAbatch = 1;
    int64_t oriShapeBbatch = 1;

    // Data-type codes. NOTE(review): the encoding is defined outside this header.
    int32_t dtypeA = 0;
    int32_t dtypeB = 0;
    int32_t dtypeOut = 0;
    int32_t dtypeBias = 0;

    // Layout / feature flags.
    bool transA = false;
    bool transB = false;
    bool formatANd = false;
    bool formatBNd = false;
    bool formatOutNd = false;
    bool biasFlag = false;
    bool hf32Flag = true;  // the only flag that defaults to true

    // 32-bit shape values.
    int32_t batch32 = 1;
    int32_t m32 = 1;
    int32_t k32 = 1;
    int32_t n32 = 1;
    int32_t batch32A = 1;
    int32_t batch32B = 1;

    // Mapped shape values.
    int32_t mMapped = 1;
    int32_t kMapped = 1;
    int32_t nMapped = 1;
    int32_t batchMapped = 1;

    bool nonFactorK = false;
};
// The macro guard skips this definition when ASCC_STRUCT_L1TILINGTYPE is already defined,
// so an earlier definition of the same enum is not repeated.
#ifndef ASCC_STRUCT_L1TILINGTYPE
#define ASCC_STRUCT_L1TILINGTYPE
// Selects the L1 axis handled by MatmulTilingAlgorithm::CalL1MaxLen (its axisName parameter).
enum class L1TilingType : uint8_t {
    KAL1_16,
    KBL1_16,
    M_AL1,
    N_BL1
};
#endif
// Result type of MatmulTilingAlgorithm::GetMultiCoreScenario.
// The numeric values are written out; they match the implicit numbering.
enum class MultiCoreScenario : uint8_t {
    K_FULL_LOAD = 0,
    SPLIT_MN = 1,
    SPLIT_SMALL_MN = 2,
    OTHERS = 3
};
// Scratch state used while searching for a multi-core split (see the DimCalculator&
// parameters of MatmulTilingAlgorithm).
//
// All defaults live in the member initialisers below. The previous hand-written constructor
// assigned nDimFactor twice (the second assignment was evidently meant for batchDimFactor),
// cleared vectors that are already empty on construction, and overrode the declared defaults
// of minLoadSize / loopNumToL0 with INT_MAX. The state after default construction is unchanged.
struct DimCalculator {
    int32_t batch = 1;
    int32_t m = 1;
    int32_t k = 1;
    int32_t n = 1;
    int32_t kNum = 1;
    int32_t kBytes = 1;

    // Currently selected factor per axis.
    int32_t batchDimFactor = 1;
    int32_t mDimFactor = 1;
    int32_t nDimFactor = 1;
    int32_t kDimFactor = 1;

    // Starts at INT_MAX so that any real load size compares smaller.
    int32_t minLoadSize = INT_MAX;
    int32_t coreUse = 1;
    int32_t tmpCoreUse = 1;
    // Starts at INT_MAX so that any real loop count compares smaller.
    int32_t loopNumToL0 = INT_MAX;

    int32_t batchIdx = 0;
    int32_t nIdx = 0;
    int32_t batchDimCnt = 0;
    int32_t mDimCnt = 0;
    int32_t nDimCnt = 0;
    int32_t kDimCnt = 0;
    int32_t batchFactorCnt = 0;

    int32_t oriAmatSize = 0;
    int32_t oriBmatSize = 0;
    int32_t amatSize = 0;
    int32_t bmatSize = 0;
    int32_t tmpAmatSize = 0;
    int32_t tmpBmatSize = 0;
    int32_t tmpLoadSize = 0;
    int32_t totalLoadSize = 0;
    int32_t tmpValue = 0;
    int32_t finalValue = 0;

    // Candidate factor lists per axis; empty after construction.
    std::vector<int32_t> batchDimFactors;
    std::vector<int32_t> mDimFactors;
    std::vector<int32_t> nDimFactors;
    std::vector<int32_t> kDimFactors;

    bool initFlag = false;
    bool bigPackage = false;

    DimCalculator() = default;
};
// One candidate split: a factor for each of batch / m / k / n plus a group factor.
class DimFactor {
public:
    // All factors start at 1.
    DimFactor() : DimFactor(1, 1, 1, 1, 1) {}

    // inGroup defaults to 1 when omitted.
    DimFactor(int32_t inBatch, int32_t inM, int32_t inK, int32_t inN, int32_t inGroup = 1)
        : batch(inBatch), m(inM), k(inK), n(inN), group(inGroup)
    {}

    // Product of the batch, m, k and n factors. Note: group does not take part in the product.
    int32_t ReduceMul() const
    {
        return batch * m * k * n;
    }

    // Resets every factor (including group) to 1.
    void Init()
    {
        *this = DimFactor();
    }

    // True only when all five factors are strictly positive.
    bool IsValid() const
    {
        return !(group <= 0 || batch <= 0 || m <= 0 || k <= 0 || n <= 0);
    }

    int32_t batch = 1;
    int32_t m = 1;
    int32_t k = 1;
    int32_t n = 1;
    int32_t group = 1;
};
// Base block extents along M, N and K.
// Aggregate with no default initialisers: members of a default-initialised object are
// indeterminate until assigned; use ComputeBaseBlock{m, n, k} or {} to initialise.
struct ComputeBaseBlock {
    int32_t baseM;
    int32_t baseN;
    int32_t baseK;
};
// Candidate record for the small-shape multi-core search, ordered by operator<.
// Members have no default initialisers and must be filled before comparison.
struct ComputeIntensitySmallShape {
    std::pair<int32_t, int32_t> dimFactor;
    float computeCycle;
    float avgIntensity;
    float bandRatio;
    float memoryTraffic;
    ComputeBaseBlock baseBlock;

    // Ordering ("less" sorts first):
    //   1. If the memory traffic differs by more than MEMORY_EPSILON relative to this object's
    //      traffic, the lower traffic sorts first.
    //   2. Otherwise, if the compute cycles differ by less than 1% of rhs's cycles, the higher
    //      band ratio sorts first.
    //   3. Otherwise the higher average intensity sorts first.
    // NOTE(review): the tolerances are relative to different operands (this vs. rhs), so the
    // relation is not obviously a strict weak ordering - confirm this is acceptable for the
    // sort/min calls that use it.
    bool operator<(const ComputeIntensitySmallShape& rhs) const
    {
        const bool trafficDiffers =
            std::fabs(rhs.memoryTraffic - memoryTraffic) > MEMORY_EPSILON * memoryTraffic;
        if (trafficDiffers) {
            return memoryTraffic < rhs.memoryTraffic;
        }
        const bool cyclesClose = std::fabs(computeCycle - rhs.computeCycle) < 0.01f * rhs.computeCycle;
        return cyclesClose ? (bandRatio > rhs.bandRatio) : (avgIntensity > rhs.avgIntensity);
    }
};
// Candidate record for the multi-core search, ordered by operator<.
// Members have no default initialisers and must be filled before comparison.
struct ComputeIntensity {
    std::pair<int32_t, int32_t> dimFactor;
    float computeCycle;
    float avgIntensity;
    float bandRatio;

    // Ordering ("less" sorts first):
    //   1. If the average intensities differ by more than MATMUL_EPSILON, the higher one
    //      sorts first.
    //   2. Otherwise, if the compute cycles differ by less than 1% of rhs's cycles, the higher
    //      band ratio sorts first.
    //   3. Otherwise the lower compute cycle sorts first.
    // NOTE(review): step 2 measures closeness relative to rhs only, so the relation is not
    // obviously a strict weak ordering - confirm this is acceptable for its callers.
    bool operator<(const ComputeIntensity& rhs) const
    {
        const bool intensityDiffers = std::fabs(avgIntensity - rhs.avgIntensity) > MATMUL_EPSILON;
        if (intensityDiffers) {
            return avgIntensity > rhs.avgIntensity;
        }
        const bool cyclesClose = std::fabs(computeCycle - rhs.computeCycle) < 0.01f * rhs.computeCycle;
        return cyclesClose ? (bandRatio > rhs.bandRatio) : (computeCycle < rhs.computeCycle);
    }
};
// Pair of memory ratios, one for the A side and one for the B side; both default to 1.0.
struct MemoryRatios {
    float aMemoryRatio = 1.0f;
    float bMemoryRatio = 1.0f;

    // explicit: a single float never converts silently into a MemoryRatios.
    explicit MemoryRatios(float a = 1.0f, float b = 1.0f) : aMemoryRatio(a), bMemoryRatio(b) {}
};
// Element type of the vector returned by MatmulTilingAlgorithm::CalcTotalCycleMemory.
// Aggregate with no default initialisers; fill every member before reading it.
struct BaseBlockIntensity {
    int32_t count;
    float computeCycle;
    float memoryTraffic;
};
// Declaration of the matmul tiling algorithm. Only the constructor, destructor, Process() and
// EnableL1BankConflictOptimise() are public; everything else is a private helper defined in
// the implementation file.
// NOTE(review): the section headings below group helpers by name only; consult the
// definitions for their actual behaviour.
class MatmulTilingAlgorithm {
public:
    // Takes the tiling instance to work on. The pointer is stored in tilingIns_.
    explicit MatmulTilingAlgorithm(MatmulApiTilingBase* tilingIns);

    // Only resets the stored pointer; the pointee is not deleted here.
    ~MatmulTilingAlgorithm()
    {
        tilingIns_ = nullptr;
    }

    // Entry point of the algorithm. NOTE(review): meaning of the int64_t result is defined in the .cpp.
    int64_t Process();
    bool EnableL1BankConflictOptimise() const;

private:
    // ---- Parameter checks and basic getters ----
    bool CheckFinalParams(const CoreStatusPack& coreStatus) const;
    bool CheckBaseMN() const;
    int32_t GetBestValue(int32_t base) const;
    int32_t GetC0Size() const;
    int32_t GetIteratorOrder(
        const SingleCoreStatus& singleCoreStatus, const int32_t singleCoreM, const int32_t singleCoreN,
        const int32_t singleCoreK) const;

    // ---- L0 factors ----
    void GetL0StatusFromParasCombo(L0StatusPack& l0Status, int32_t* parasCombo) const;
    void GetTwoFactors(int32_t (&res)[2], int32_t base, int32_t dim, int32_t maxNum = 32) const;
    void SetResFactors(L0Factors& resFactors, const L0StatusPack& l0Status) const;
    int32_t GetLoadSize(const CoreStatusPack& coreStatus, const L0StatusPack& l0Status) const;
    void GetFinalMkn(
        SingleCoreStatus& singleCoreStatus, const CoreStatusPack& coreStatus, const int32_t& k0,
        const int32_t& majorDimFactor, const int32_t& minorDimFactor) const;
    bool CheckK0Align(int32_t k0) const;
    void GetL0FactorsCand(
        L0Factors& resFactors, const CoreStatusPack& coreStatus, SingleCoreStatus& singleCoreStatus,
        int32_t* parasCombo, const MatmulRunParas& param) const;
    MKNParasCombo GetParasCombo(const int32_t& index, const MatmulRunParas& param) const;
    void GetL0cDB(
        const L0Factors (&resFactors)[L0PARAS_COMBO_LEN], const CoreStatusPack& coreStatus,
        L0StatusPack& l0Status) const;
    int32_t GetMxCurL1Size(const SingleCoreStatus& singleCoreStatus) const;
    void GetL0Factors(
        const std::string& opType, const MatmulRunParas& param, const CoreStatusPack& coreStatus,
        SingleCoreStatus& singleCoreStatus) const;

    // ---- Sparse / Mx adjustments ----
    void AdjustSparseL0Factors(SingleCoreStatus& singleCoreStatus) const;
    void AdjustMxL0Factors(SingleCoreStatus& singleCoreStatus) const;
    void UpdateBaseKForMxGemv(int32_t& baseK, SingleCoreStatus& singleCoreStatus) const;
    void AdjustMxL1Factors(SingleCoreStatus& singleCoreStatus) const;
    void FixMxScaleFactorByRange(uint8_t& factor, uint8_t maxFactor) const;
    void FixMxScaleFactorByPosition(
        uint8_t& scaleFactorM, uint8_t& scaleFactorN, uint8_t& scaleFactorKa, uint8_t& scaleFactorKb) const;
    void GetMxScaleSize(int32_t& scaleA1Size, int32_t& scaleB1Size) const;
    void GetMxScaleFactor(const SingleCoreStatus& singleCoreStatus, int32_t& mxTypePara) const;
    void CheckL0DB(SingleCoreStatus& singleCoreStatus, const int32_t baseK) const;
    void CheckL0DB(int32_t baseM, int32_t baseN, int32_t baseK, SingleCoreStatus& singleCoreStatus) const;
    void GetMxUsedL1Size(
        const SingleCoreStatus& singleCoreStatus, int32_t& dataUsedL1Size, int32_t& scaleUsedL1Size,
        int32_t& biasUsedL1Size) const;

    // ---- NBuffer33 adjustments and base-shape checks ----
    bool AdjustNBuffer33L0Factors(
        const MatmulRunParas& param, const CoreStatusPack& coreStatus, SingleCoreStatus& singleCoreStatus) const;
    bool AdjustNBuffer33L1Factors(const CoreStatusPack& coreStatus, SingleCoreStatus& singleCoreStatus) const;
    bool CheckFixSplitInputs(int32_t singleCoreM) const;
    void CalcBaseShape(
        const SingleCoreStatus& singleCoreStatus, int32_t& baseM, int32_t& baseN, int32_t& baseK,
        int32_t& reduceSize) const;
    bool CheckL0ASize(int32_t singleCoreM, int32_t singleCoreK, int32_t& baseM, int32_t& baseK) const;
    bool CheckL0BSize(int32_t singleCoreN, int32_t singleCoreK, int32_t& baseN, int32_t& baseK) const;
    bool CheckL0CSize(int32_t singleCoreM, int32_t singleCoreN, int32_t& baseM, int32_t& baseN) const;
    int32_t GetNBuffer33L1Size(const SingleCoreStatus& singleCoreStatus) const;

    // ---- L1 factors ----
    bool IsNeedAlign(bool isA) const;
    void GetABL1Const(int32_t& aL1Const, int32_t& bL1Const, const L1StatusPack& l1Status) const;
    int32_t GetL1Size(const L1StatusPack& l1Status, const L0StatusPack& l0Status) const;
    int32_t CalL1MaxLen(
        int32_t resL1Size, L1StatusPack& l1Status, const L0StatusPack& l0Status, const int32_t alignValue,
        const L1TilingType axisName) const;
    void GetNearestFactor(const int32_t& base, int32_t& factor, int32_t capValue = INT32_MAX) const;
    void L1StatusAl1FullLoad(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status, int32_t res[][7]) const;
    void L1StatusBl1FullLoad(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status, int32_t res[][7]) const;
    void L1StatusBothFullLoad(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status, int32_t res[][7]) const;
    void NeitherFullLoadDb(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status,
        const int32_t& kbl1Db) const;
    void NeitherFullLoadMN(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status) const;
    void NeitherFullLoadKforNZ(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status) const;
    void NeitherFullLoadKforND(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status,
        const int32_t& kMaxAxis) const;
    void NeitherFullLoadK(const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status) const;
    void L1StatusNeitherFullLoad(
        const CoreStatusPack& coreStatus, const L0StatusPack& l0Status, L1StatusPack& l1Status, int32_t res[][7]) const;
    void GetL1Factors(
        const std::string& opType, const MatmulRunParas& param, const CoreStatusPack& coreStatus,
        const L0StatusPack& l0Status, L1StatusPack& l1Status) const;
    bool CheckL1Size(int32_t amat, int32_t bmat, int32_t curBiasL1Size = 0) const;

    // ---- Buffer usage ----
    void GetUsedSize(
        int32_t& l1Size, int32_t& l0cSize, int32_t& ubSize, int32_t a1LengthCache, int32_t b1LengthCache) const;
    void GetBankConflictSize(int32_t& length, bool isAMatrix) const;
    void GetBankConflictSize(
        const L1StatusPack& l1Status, const L0StatusPack& l0Status, int32_t& length, bool isAMatrix) const;
    int32_t GetAL1UbSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status) const;
    int32_t GetBL1UbSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status) const;
    bool IsUbNd2Nz() const;
    void GetTransLength(int32_t& transLength) const;
    void SetDepthL1CacheUBParams(int32_t& a1LengthCache, int32_t& b1LengthCache) const;
    void GetABL1KAlignValue(int32_t& kaAlignValue, int32_t& kbAlignValue) const;
    bool CheckBaseMNKL1Size(SingleCoreStatus& singleCoreStatus) const;

    // ---- Multi-core dimension search ----
    void NonFactorMap(const std::string& opType, MatmulRunParas& param, DimCalculator& dimCalRes) const;
    void UpdateDimCalculator(DimCalculator& dimCalRes) const;
    void GetDimsHelper(
        const DimFactor& dimFactor, CoreStatusPack& coreStatus, DimCalculator& dimCalRes, const MatmulRunParas& params);
    void GetDims(
        const std::string& opType, MatmulRunParas& params, CoreStatusPack& coreStatus, DimCalculator& dimCalRes);
    bool PreProcessMiniShape(
        const std::string& opType, CoreStatusPack& coreStatus, MatmulRunParas& params, const int32_t& coreNum,
        bool splitKFlag) const;
    float CalculateBlockCycles(int32_t baseM, int32_t baseN, int32_t baseK) const;
    float CalculateMemoryTraffic(
        int32_t baseM, int32_t baseN, int32_t baseK, float aMemoryRatio, float bMemoryRatio) const;
    bool AlignSingleShape(
        bool needAlign, int32_t orgShape, int32_t factor, int32_t alignSize, int32_t& singleShape) const;
    ComputeIntensity CalcComputeIntensity(
        const MatmulRunParas& params, const ComputeBaseBlock& baseBlock,
        const std::pair<int32_t, int32_t>& factor) const;
    ComputeIntensitySmallShape CalcComputeIntensitySmallShape(
        const MatmulRunParas& params, const std::pair<int32_t, int32_t>& factor, ComputeBaseBlock& baseBlock) const;
    void CalcMultiCoreDims(
        const MatmulRunParas& params, const ComputeBaseBlock& baseBlock, CoreStatusPack& coreStatus,
        DimCalculator& dimCalRes);
    void CalcMultiCoreDimsPost(const MatmulRunParas& params, CoreStatusPack& coreStatus, DimCalculator& dimCalRes);
    void CalcMultiCoreDimsSmallShape(
        const MatmulRunParas& params, ComputeBaseBlock& baseBlock, CoreStatusPack& coreStatus,
        DimCalculator& dimCalRes);
    bool CalcNBuffer33Dims(
        const MatmulRunParas& params, const ComputeBaseBlock& baseBlock, CoreStatusPack& coreStatus) const;
    void UpdateBaseBlock(
        const MatmulRunParas& params, const int32_t sm, const int32_t sn, ComputeBaseBlock& baseBlock) const;
    std::vector<BaseBlockIntensity> CalcTotalCycleMemory(
        const std::pair<int32_t, int32_t>& shapeM, const std::pair<int32_t, int32_t>& shapeN,
        const ComputeBaseBlock& baseBlock, const float memoryRatio,
        const MemoryRatios memoryRatios = MemoryRatios()) const;
    void UpdateMultiCore(
        const std::string& opType, const MatmulRunParas& params, CoreStatusPack& coreStatus,
        const DimCalculator& dimCalRes) const;
    void CalcLoadSize(
        const DimFactor& dimFactor, const CoreStatusPack& coreStatus, DimCalculator& dimCalRes,
        const MatmulRunParas& params) const;
    void FillParam(MatmulRunParas& param);
    bool IsInvalidFactor(int32_t factor) const;
    void AddOptimalFactors(const std::string& opType, const MatmulRunParas& params, DimCalculator& dimCalRes) const;
    int32_t LoopNumFromSingleCoreToL0(const CoreStatusPack& coreStatus, const DimFactor& dimFactor) const;
    void GenDimsMapFactors(const std::string& opType, MatmulRunParas& params, DimCalculator& dimCalRes) const;

    // ---- Policy and single-shape handling ----
    void UpdateBufferSize(const TilingPolicy policy, const CoreStatusPack& coreStatus) const;
    bool UserPolicy(const TilingPolicy policy, const CoreStatusPack& coreStatus, const DimCalculator& dimCalRes) const;
    void PreprocessL0DB();
    void GetL0bAlign(std::vector<int32_t>& factors) const;
    int32_t GetBigPackageCondition(
        const CoreStatusPack& coreStatus, const DimCalculator& dimCalRes, const MatmulRunParas& params) const;
    int UpdateDepthB1(const SingleCoreStatus& singleCoreStatus) const;
    void GetSingleShape(
        const CoreStatusPack& coreStatus, const MatmulRunParas& param, int32_t& singleCoreM, int32_t& singleCoreN,
        int32_t& singleCoreK) const;
    bool CheckSingleShape(int32_t singleCoreM, int32_t singleCoreN, int32_t singleCoreK) const;
    MultiCoreScenario GetMultiCoreScenario(const MatmulRunParas& params) const;
    ComputeBaseBlock GetMultiCoreBasicBlock(const MatmulRunParas& params) const;
    int32_t GetSingleM() const;
    int32_t GetSingleN() const;
    int32_t GetSingleK() const;
    float CalcBaseBlockBandRatio(int32_t mDim, int32_t nDim, const ComputeBaseBlock& baseBlock) const;
    bool DoMultiCoreSplitMNTiling(const MatmulRunParas& params, CoreStatusPack& coreStatus, DimCalculator& dimCalRes);
    void CalcL1Tiling(
        const ComputeBaseBlock& baseBlock, int32_t& depthA1, int32_t& depthB1, int32_t& stepKa, int32_t& stepKb) const;
    void UpdateStepK(const ComputeBaseBlock& baseBlock, int32_t& stepK) const;
    bool NeedOutputAlign(int32_t m, int32_t n, int32_t k) const;
    void UpdateUsedSize() const;
    int64_t AdjustOuterProductL0Factor(SingleCoreStatus& singleCoreStatus) const;
    L0StatusPack GetL0CoreStatus(const ComputeBaseBlock& baseBlock) const;
    L1StatusPack GetL1CoreStatus(
        const ComputeBaseBlock& baseBlock, int32_t depthA1, int32_t depthB1, int32_t stepKa, int32_t stepKb) const;
    void UpdateShapeAndLayout() const;
    void AdjustFloatL1Factor(const SingleCoreStatus& singleCoreStatus) const;
    int64_t UpdateTiling(
        const MatmulRunParas& param, const CoreStatusPack& coreStatus, SingleCoreStatus& singleCoreStatus) const;

    // ---- Alignment and byte-size getters ----
    int32_t GetC0Size(DataType dataType) const;
    int32_t GetABaseHeightAlign(int32_t baseHeight) const;
    int32_t GetABaseWidthAlign(int32_t baseWidth) const;
    int32_t GetBBaseHeightAlign(int32_t baseHeight) const;
    int32_t GetBBaseWidthAlign(int32_t baseWidth) const;
    int32_t GetScaleABaseHeightAlign() const;
    int32_t GetScaleABaseWidthAlign() const;
    int32_t GetScaleBBaseHeightAlign() const;
    int32_t GetScaleBBaseWidthAlign() const;
    int32_t GetMatrixAByteSize() const;
    int32_t GetMatrixBByteSize() const;
    int32_t GetMatrixScaleAByteSize() const;
    int32_t GetMatrixScaleBByteSize() const;
    void CalABAndScaleABL1Space(
        int32_t matrixByteSize, int32_t cacheNum, int32_t stepSize, uint32_t& curL1UpperHalfAddr,
        uint32_t& curL1LowerHalfAddr) const;
    void SetBaseMNK(const SingleCoreStatus& singleCoreStatus) const;

private:
    // Tiling instance supplied at construction; reset to nullptr (not deleted) by the destructor.
    MatmulApiTilingBase* tilingIns_ = nullptr;
    bool enableSingleShape_ = false;
    bool splitCoreFlag_ = false;
    // Buffer modes for L0A / L0B / L0C, holding DB_ON or DB_OFF.
    int32_t dbL0A_ = DB_ON;
    int32_t dbL0B_ = DB_ON;
    int32_t dbL0C_ = DB_OFF;
    // NOTE(review): presumably the default block/core count - confirm where 24 comes from.
    int32_t numOfBlock_ = 24;
};
}
#endif