/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
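/*!
 * \file matmul_constant_tiling_shape_utils.h
 * \brief Compile-time helpers that size MX scale buffers and derive the packed MX scale factors
 *        used by the constant matmul tiling.
 */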
#ifndef IMPL_MATMUL_TILING_MATMUL_CONSTANT_TILING_SHAPE_UTILS_H
#define IMPL_MATMUL_TILING_MATMUL_CONSTANT_TILING_SHAPE_UTILS_H
#include "matmul_constant_tiling_struct.h"
#include "matmul_constant_tiling_utils.h"
namespace AscendC {
namespace Impl {
// K extents are divided by this (rounded up) to obtain the number of scale entries along K.
constexpr int32_t SCALE_K_SIZE{32};
// 0x01 in every byte: the value mxTypePara takes when all four packed MX scale factors are 1.
constexpr int32_t MIN_MX_PARAM{0x01010101};
// Divisor applied to the leftover L1 size when choosing MX scale factors.
constexpr int32_t MX_L1_BUFFER_NUM{4};
// Upper clamp applied to each individual MX scale factor.
constexpr uint32_t SCALE_FACTOR_MAX_VALUE{127U};
}
/**
 * \brief MX scale-factor result filled in by GetMxScaleFactor.
 *
 * When computed by GetMxScaleFactor, each scaleFactor* field is clamped by FixMxScaleFactor to
 * [1, Impl::SCALE_FACTOR_MAX_VALUE]. mxTypePara packs the four factors one per byte:
 * bits 0-7 = scaleFactorKa, bits 8-15 = scaleFactorKb, bits 16-23 = scaleFactorM, bits 24-31 = scaleFactorN.
 */
struct MxScaleStatus {
    // Factor for the A-side scale along K; GetMxScaleFactor bounds it by ceil(kStep / stepKa).
    uint8_t scaleFactorKa;
    // Factor for the B-side scale along K; GetMxScaleFactor bounds it by ceil(kStep / stepKb).
    uint8_t scaleFactorKb;
    // Factor for the A-side scale along M; only raised above 1 once scaleFactorKa hits its bound.
    uint8_t scaleFactorM;
    // Factor for the B-side scale along N; only raised above 1 once scaleFactorKb hits its bound.
    uint8_t scaleFactorN;
    // Packed form of the four factors above (see layout in the struct comment).
    int32_t mxTypePara;
};
/**
 * \brief Size of the A-side MX scale data counted against L1 for the given L1 tiling status.
 *
 * \param l1Status L1 tiling status (dbAL1, mAL1 and kAL1 are read)
 * \param mmCFG    matmul configuration
 * \return 0 when A's scale is positioned in L1;
 *         singleCoreM * ceil(singleCoreK / MX_BASEK_FACTOR) * ALIGN_TWO when it is positioned in UB;
 *         otherwise dbAL1 * mAL1 * basicM * ceil(kAL1 / kL0) * ceil(basicK / C0_BYTE_SIZE), scaled by
 *         the fp8_e8m0_t bit width over ONE_BYTE_BIT_SIZE.
 */
template <typename A_TYPE>
__aicore__ constexpr int32_t GetScaleAL1Size(const L1Status& l1Status, const MatmulConfig& mmCFG)
{
    const int32_t kL0 = GetKL0<A_TYPE>(mmCFG);
    if constexpr (PhyPosIsL1(A_TYPE::scalePosition)) {
        return 0;
    } else if constexpr (PhyPosIsUB(A_TYPE::scalePosition)) {
        return mmCFG.singleCoreM * CeilNoLog<int32_t>(mmCFG.singleCoreK, Impl::MX_BASEK_FACTOR) * Impl::ALIGN_TWO;
    } else {
        // Number of kL0-sized K blocks held in L1, and number of C0_BYTE_SIZE chunks per basicK.
        const auto kBlocksInL1 = CeilNoLog<int32_t>(l1Status.kAL1, kL0);
        const auto chunksPerBaseK = CeilNoLog<int32_t>(mmCFG.basicK, Impl::C0_BYTE_SIZE);
        return l1Status.dbAL1 * l1Status.mAL1 * mmCFG.basicM * kBlocksInL1 * chunksPerBaseK *
            GetBitSize<fp8_e8m0_t>() / ONE_BYTE_BIT_SIZE;
    }
}
/**
 * \brief Size of the B-side MX scale data counted against L1 for the given L1 tiling status.
 *
 * \param l1Status L1 tiling status (dbBL1, nBL1 and kBL1 are read)
 * \param mmCFG    matmul configuration
 * \return 0 when B's scale is positioned in L1;
 *         singleCoreN * ceil(singleCoreK / MX_BASEK_FACTOR) * ALIGN_TWO when it is positioned in UB;
 *         otherwise dbBL1 * nBL1 * basicN * ceil(kBL1 / kL0) * ceil(basicK / C0_BYTE_SIZE), scaled by
 *         the fp8_e8m0_t bit width over ONE_BYTE_BIT_SIZE.
 */
template <typename B_TYPE>
__aicore__ constexpr int32_t GetScaleBL1Size(const L1Status& l1Status, const MatmulConfig& mmCFG)
{
    const int32_t kL0 = GetKL0<B_TYPE>(mmCFG);
    if constexpr (PhyPosIsL1(B_TYPE::scalePosition)) {
        return 0;
    } else if constexpr (PhyPosIsUB(B_TYPE::scalePosition)) {
        return mmCFG.singleCoreN * CeilNoLog<int32_t>(mmCFG.singleCoreK, Impl::MX_BASEK_FACTOR) * Impl::ALIGN_TWO;
    } else {
        // Number of kL0-sized K blocks held in L1, and number of C0_BYTE_SIZE chunks per basicK.
        const auto kBlocksInL1 = CeilNoLog<int32_t>(l1Status.kBL1, kL0);
        const auto chunksPerBaseK = CeilNoLog<int32_t>(mmCFG.basicK, Impl::C0_BYTE_SIZE);
        return l1Status.dbBL1 * l1Status.nBL1 * mmCFG.basicN * kBlocksInL1 * chunksPerBaseK *
            GetBitSize<fp8_e8m0_t>() / ONE_BYTE_BIT_SIZE;
    }
}
/**
 * \brief Clamp a candidate MX scale factor.
 *
 * \param factor    candidate factor (may be <= 0 or larger than allowed)
 * \param maxFactor tiling-derived upper bound for this factor
 * \return min(max(min(factor, maxFactor), 1), Impl::SCALE_FACTOR_MAX_VALUE)
 */
__aicore__ constexpr int32_t FixMxScaleFactor(int32_t factor, int32_t maxFactor)
{
    // SCALE_FACTOR_MAX_VALUE is unsigned; convert it once so every comparison and
    // conditional below stays in signed int32_t (no implicit unsigned round-trip).
    constexpr int32_t upperLimit = static_cast<int32_t>(Impl::SCALE_FACTOR_MAX_VALUE);
    factor = factor < maxFactor ? factor : maxFactor;
    factor = factor > 1 ? factor : 1;
    factor = factor < upperLimit ? factor : upperLimit;
    return factor;
}
/**
 * \brief Height of one A-side scale base block: baseM aligned up to GetReduceC0Size<fp8_e8m0_t>().
 */
__aicore__ constexpr int32_t GetScaleABaseHeightAlign(const MatmulApiStaticTiling& tiling)
{
    const auto alignUnit = GetReduceC0Size<fp8_e8m0_t>();
    return Align<int32_t>(tiling.baseM, alignUnit);
}
/**
 * \brief Width of one A-side scale base block: ceil(baseK / SCALE_K_SIZE).
 *        No extra alignment is applied to this dimension.
 */
__aicore__ constexpr int32_t GetScaleABaseWidthAlign(const MatmulApiStaticTiling& tiling)
{
    const auto scaleEntriesAlongK = CeilNoLog<int32_t>(tiling.baseK, Impl::SCALE_K_SIZE);
    return scaleEntriesAlongK;
}
/**
 * \brief Height of one B-side scale base block: ceil(baseK / SCALE_K_SIZE) aligned up to
 *        GetReduceC0Size<fp8_e8m0_t>().
 */
__aicore__ constexpr int32_t GetScaleBBaseHeightAlign(const MatmulApiStaticTiling& tiling)
{
    const auto scaleEntriesAlongK = CeilNoLog<int32_t>(tiling.baseK, Impl::SCALE_K_SIZE);
    return Align<int32_t>(scaleEntriesAlongK, GetReduceC0Size<fp8_e8m0_t>());
}
/**
 * \brief Width of one B-side scale base block.
 *
 * \return baseN aligned up to GetReduceC0Size<fp8_e8m0_t>() when B_TYPE::isScaleTrans is false,
 *         otherwise baseN unchanged.
 */
template <typename B_TYPE>
__aicore__ constexpr int32_t GetScaleBBaseWidthAlign(const MatmulApiStaticTiling& tiling)
{
    // isScaleTrans is a static trait of B_TYPE; select the branch at compile time, as the other
    // trait checks in this file (isTrans, format, scalePosition) already do.
    if constexpr (B_TYPE::isScaleTrans == false) {
        return Align<int32_t>(tiling.baseN, GetReduceC0Size<fp8_e8m0_t>());
    } else {
        return tiling.baseN;
    }
}
/**
 * \brief Size of one A-side scale matrix block, depending on where the scale is positioned.
 *
 * \return UB: align(singleCoreM, HW_C0) * align(ceil(singleCoreK / SCALE_K_SIZE), C0_BYTE_SIZE);
 *         GM: GetScaleABaseHeightAlign(tiling) * GetScaleABaseWidthAlign(tiling);
 *         any other position: 0.
 */
template <typename A_TYPE>
__aicore__ constexpr int32_t GetMatrixScaleAByteSize(const MatmulApiStaticTiling& tiling)
{
    if constexpr (PhyPosIsUB(A_TYPE::scalePosition)) {
        // Uses the single-core extents.
        const auto alignedRows = Align<int32_t>(tiling.singleCoreM, Impl::HW_C0);
        const auto alignedCols =
            Align<int32_t>(CeilNoLog<int32_t>(tiling.singleCoreK, Impl::SCALE_K_SIZE), Impl::C0_BYTE_SIZE);
        return alignedRows * alignedCols;
    }
    if constexpr (PhyPosIsGM(A_TYPE::scalePosition)) {
        // Uses the base-block extents.
        return GetScaleABaseHeightAlign(tiling) * GetScaleABaseWidthAlign(tiling);
    }
    return 0;
}
/**
 * \brief Size of one B-side scale matrix block, depending on where the scale is positioned.
 *
 * \return UB: align(ceil(singleCoreK / SCALE_K_SIZE), HW_C0) * align(singleCoreN, C0_BYTE_SIZE);
 *         GM: GetScaleBBaseHeightAlign(tiling) * GetScaleBBaseWidthAlign<B_TYPE>(tiling);
 *         any other position: 0.
 */
template <typename B_TYPE>
__aicore__ constexpr int32_t GetMatrixScaleBByteSize(const MatmulApiStaticTiling& tiling)
{
    if constexpr (PhyPosIsUB(B_TYPE::scalePosition)) {
        // Uses the single-core extents.
        const auto alignedRows =
            Align<int32_t>(CeilNoLog<int32_t>(tiling.singleCoreK, Impl::SCALE_K_SIZE), Impl::HW_C0);
        const auto alignedCols = Align<int32_t>(tiling.singleCoreN, Impl::C0_BYTE_SIZE);
        return alignedRows * alignedCols;
    }
    if constexpr (PhyPosIsGM(B_TYPE::scalePosition)) {
        // Uses the base-block extents.
        return GetScaleBBaseHeightAlign(tiling) * GetScaleBBaseWidthAlign<B_TYPE>(tiling);
    }
    return 0;
}
/**
 * \brief Compute the A and B scale sizes that GetMxScaleFactor divides the leftover L1 budget by.
 *
 * \param tiling      static tiling parameters
 * \param scaleA1Size [out] A-side scale size
 * \param scaleB1Size [out] B-side scale size
 *
 * If a side's scale is positioned in L1, its size covers the whole single-core extent:
 * align(singleCoreM or singleCoreN, C0_BYTE_SIZE) * (ceil(singleCoreK / MX_BASEK_FACTOR) * ALIGN_TWO).
 * Otherwise it is stepK * stepM (or stepN) times the GetMatrixScale*ByteSize block, scaled by the
 * fp8_e8m0_t bit width over ONE_BYTE_BIT_SIZE.
 */
template <typename A_TYPE, typename B_TYPE>
__aicore__ constexpr void GetMxScaleSize(const MatmulApiStaticTiling& tiling, int& scaleA1Size, int& scaleB1Size)
{
    if constexpr (!PhyPosIsL1(A_TYPE::scalePosition)) {
        const auto blockA =
            GetMatrixScaleAByteSize<A_TYPE>(tiling) * GetBitSize<fp8_e8m0_t>() / ONE_BYTE_BIT_SIZE;
        scaleA1Size = tiling.stepKa * tiling.stepM * blockA;
    } else {
        const auto kScaleColsA = CeilNoLog<int32_t>(tiling.singleCoreK, Impl::MX_BASEK_FACTOR) * Impl::ALIGN_TWO;
        scaleA1Size = Align<int32_t>(tiling.singleCoreM, Impl::C0_BYTE_SIZE) * kScaleColsA;
    }

    if constexpr (!PhyPosIsL1(B_TYPE::scalePosition)) {
        const auto blockB =
            GetMatrixScaleBByteSize<B_TYPE>(tiling) * GetBitSize<fp8_e8m0_t>() / ONE_BYTE_BIT_SIZE;
        scaleB1Size = tiling.stepKb * tiling.stepN * blockB;
    } else {
        const auto kScaleColsB = CeilNoLog<int32_t>(tiling.singleCoreK, Impl::MX_BASEK_FACTOR) * Impl::ALIGN_TWO;
        scaleB1Size = Align<int32_t>(tiling.singleCoreN, Impl::C0_BYTE_SIZE) * kScaleColsB;
    }
}
/**
 * \brief Choose the MX scale factors for A and B from the L1 space left over by the tiling, then pack them.
 *
 * \param tiling static tiling parameters
 * \param l1Size total L1 size available
 * \return MxScaleStatus with every factor clamped by FixMxScaleFactor, and mxTypePara holding
 *         Ka | Kb << 8 | M << 16 | N << 24.
 *
 * Steps:
 * - The leftover L1 is (l1Size - GetL1UsedSize(tiling, all factors = 1)) / Impl::MX_L1_BUFFER_NUM.
 * - K factors: leftover / scaleSize + 1, bounded by ceil(kStep / stepK).
 * - M/N factors: raised only when the matching K factor reached its bound.
 *   Their candidate is leftover / (K factor * scaleSize), bounded by ceil(mStep or nStep / stepM or stepN).
 * - A side whose scale position is TPosition::TSCM has its M and Ka factors forced back to 1;
 *   likewise N and Kb for B.
 * - If a scale size computes to 0 (e.g. a scale position that is neither L1, UB nor GM), the
 *   corresponding factors stay at 1 instead of dividing by zero.
 */
template <typename A_TYPE, typename B_TYPE, typename BIAS_TYPE>
__aicore__ constexpr MxScaleStatus GetMxScaleFactor(const MatmulApiStaticTiling& tiling, int32_t l1Size)
{
    MxScaleStatus mxScaleFactor{1, 1, 1, 1, 0};
    const int32_t remainedL1BufferSize =
        (l1Size - GetL1UsedSize<A_TYPE, B_TYPE, BIAS_TYPE>(tiling, mxScaleFactor)) / Impl::MX_L1_BUFFER_NUM;
    const int32_t kStep = CeilNoLog<int32_t>(tiling.singleCoreK, tiling.baseK);

    int scaleA1Size = 0;
    int scaleB1Size = 0;
    GetMxScaleSize<A_TYPE, B_TYPE>(tiling, scaleA1Size, scaleB1Size);

    // K direction: candidate factor is leftover / size + 1; a zero size would be a division by zero,
    // so keep the default factor of 1 in that case.
    const int32_t oriScaleFactorKa = (scaleA1Size > 0) ? remainedL1BufferSize / scaleA1Size + 1 : 1;
    const int32_t maxScaleFactorKa = CeilNoLog<int32_t>(kStep, tiling.stepKa);
    mxScaleFactor.scaleFactorKa = static_cast<uint8_t>(FixMxScaleFactor(oriScaleFactorKa, maxScaleFactorKa));

    const int32_t oriScaleFactorKb = (scaleB1Size > 0) ? remainedL1BufferSize / scaleB1Size + 1 : 1;
    const int32_t maxScaleFactorKb = CeilNoLog<int32_t>(kStep, tiling.stepKb);
    mxScaleFactor.scaleFactorKb = static_cast<uint8_t>(FixMxScaleFactor(oriScaleFactorKb, maxScaleFactorKb));

    // M/N direction: only considered once the K factor on that side reached its upper bound.
    // scaleFactorK* >= 1 here, so the divisor is non-zero whenever the scale size is positive.
    if (scaleA1Size > 0 && mxScaleFactor.scaleFactorKa == maxScaleFactorKa) {
        const int32_t mStep = CeilNoLog<int32_t>(tiling.singleCoreM, tiling.baseM);
        const int32_t oriScaleFactorM = remainedL1BufferSize / (mxScaleFactor.scaleFactorKa * scaleA1Size);
        const int32_t maxScaleFactorM = CeilNoLog<int32_t>(mStep, tiling.stepM);
        mxScaleFactor.scaleFactorM = static_cast<uint8_t>(FixMxScaleFactor(oriScaleFactorM, maxScaleFactorM));
    }
    if (scaleB1Size > 0 && mxScaleFactor.scaleFactorKb == maxScaleFactorKb) {
        const int32_t nStep = CeilNoLog<int32_t>(tiling.singleCoreN, tiling.baseN);
        const int32_t oriScaleFactorN = remainedL1BufferSize / (mxScaleFactor.scaleFactorKb * scaleB1Size);
        const int32_t maxScaleFactorN = CeilNoLog<int32_t>(nStep, tiling.stepN);
        mxScaleFactor.scaleFactorN = static_cast<uint8_t>(FixMxScaleFactor(oriScaleFactorN, maxScaleFactorN));
    }

    // Scale positioned in TSCM: reset that side's factors to 1. When both sides are TSCM this
    // resets all four factors, which covers the ND/transpose-specific case as well.
    if constexpr (A_TYPE::scalePosition == TPosition::TSCM) {
        mxScaleFactor.scaleFactorM = static_cast<uint8_t>(1);
        mxScaleFactor.scaleFactorKa = static_cast<uint8_t>(1);
    }
    if constexpr (B_TYPE::scalePosition == TPosition::TSCM) {
        mxScaleFactor.scaleFactorN = static_cast<uint8_t>(1);
        mxScaleFactor.scaleFactorKb = static_cast<uint8_t>(1);
    }

    // Pack one factor per byte: Ka -> bits [0,8), Kb -> [8,16), M -> [16,24), N -> [24,32).
    mxScaleFactor.mxTypePara = static_cast<int32_t>(
        static_cast<uint32_t>(mxScaleFactor.scaleFactorKa) |
        (static_cast<uint32_t>(mxScaleFactor.scaleFactorKb) << 8U) |
        (static_cast<uint32_t>(mxScaleFactor.scaleFactorM) << 16U) |
        (static_cast<uint32_t>(mxScaleFactor.scaleFactorN) << 24U));
    return mxScaleFactor;
}
}
#endif