/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
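/*!
* \file mx_matmul_utils.h
* \brief Compile-time type traits and constexpr helper functions for the matmul implementation
*        (element-type classification, element bit sizes, scale-position queries).
*/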
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/matmul/utils/mx_matmul_utils.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/matmul/matmul.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_UTILS_MX_MATMUL_UTILS_H__
#endif
#ifndef IMPL_MATMUL_UTILS_MX_MATMUL_UTILS_H
#define IMPL_MATMUL_UTILS_MX_MATMUL_UTILS_H
#include "matmul_type_def.h"
#include "../feature_trait/matmul_feature_trait.h"
namespace AscendC {
template <typename T, typename U>
constexpr bool IsSameTypeV = AscendC::IsSameType<T, U>::value;
template <typename T, typename... Others>
struct IsTypeOneOf {
static constexpr bool value = false;
};
template <typename T, typename First, typename... Others>
struct IsTypeOneOf<T, First, Others...> {
static constexpr bool value = IsSameTypeV<T, First> || IsTypeOneOf<T, Others...>::value;
};
/*!
 * \brief Value shorthand for IsTypeOneOf<T, Candidates...>::value.
 * \tparam T          type being looked up
 * \tparam Candidates types T is compared against
 */
template <typename T, typename... Candidates>
constexpr bool IsTypeOneOfV = IsTypeOneOf<T, Candidates...>::value;
/*!
 * \brief Type selector whose nested `Type` is the element type itself.
 * \tparam ElemT element type to map
 *
 * This is the primary template (identity mapping); explicit specializations may map
 * particular element types to a different `Type`.
 */
template <typename ElemT>
struct GetMmDstType {
    using Type = ElemT;
};
/*!
 * \brief Type selector whose nested `Type` is the element type itself, regardless of the MX flag.
 * \tparam ElemT    element type to map
 * \tparam isMxType selects the MX flavour of the mapping; defaults to false
 *
 * This is the primary template (identity mapping); explicit specializations may map
 * particular (type, flag) pairs to a different `Type`.
 */
template <typename ElemT, bool isMxType = false>
struct GetL0DataType {
    using Type = ElemT;
};
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
/* GetMmDstType: every 8-bit float element type maps to float. */
template <>
struct GetMmDstType<hifloat8_t> {
    using Type = float;
};
template <>
struct GetMmDstType<fp8_e4m3fn_t> {
    using Type = float;
};
template <>
struct GetMmDstType<fp8_e5m2_t> {
    using Type = float;
};

/* GetMmDstType: every packed 4-bit float element type maps to float. */
template <>
struct GetMmDstType<fp4x2_e2m1_t> {
    using Type = float;
};
template <>
struct GetMmDstType<fp4x2_e1m2_t> {
    using Type = float;
};

/* GetL0DataType for fp8_e4m3fn_t: the MX flag selects mx_fp8_e4m3_t, otherwise the type is kept. */
template <>
struct GetL0DataType<fp8_e4m3fn_t, true> {
    using Type = AscendC::mx_fp8_e4m3_t;
};
template <>
struct GetL0DataType<fp8_e4m3fn_t, false> {
    using Type = fp8_e4m3fn_t;
};

/* GetL0DataType for fp8_e5m2_t: the MX flag selects mx_fp8_e5m2_t, otherwise the type is kept. */
template <>
struct GetL0DataType<fp8_e5m2_t, true> {
    using Type = AscendC::mx_fp8_e5m2_t;
};
template <>
struct GetL0DataType<fp8_e5m2_t, false> {
    using Type = fp8_e5m2_t;
};
#endif
/*!
 * \brief Selects the C0 size constant that matches the element type SrcT.
 * \tparam SrcT element type
 * \return Impl::B32_C0SIZE when sizeof(SrcT) equals sizeof(float); otherwise Impl::B8_C0SIZE or
 *         Impl::B4_C0SIZE for the 8-bit / 4-bit types listed below; otherwise Impl::B16_C0SIZE.
 */
template <typename SrcT>
__aicore__ inline constexpr static int32_t AuxGetC0Size()
{
    // The size-based check comes first and wins over every type-based check.
    if (sizeof(SrcT) == sizeof(float)) {
        return Impl::B32_C0SIZE;
    }
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
    // On this architecture the 8-bit and 4-bit float formats are recognised as well.
    constexpr bool isB8Type = IsTypeOneOfV<SrcT, uint8_t, int8_t, hifloat8_t, fp8_e4m3fn_t, fp8_e5m2_t, fp8_e8m0_t>;
    constexpr bool isB4Type = IsTypeOneOfV<SrcT, int4b_t, fp4x2_e1m2_t, fp4x2_e2m1_t>;
#else
    constexpr bool isB8Type = IsSameType<SrcT, int8_t>::value;
    constexpr bool isB4Type = IsSameType<SrcT, int4b_t>::value;
#endif
    if (isB8Type) {
        return Impl::B8_C0SIZE;
    }
    if (isB4Type) {
        return Impl::B4_C0SIZE;
    }
    // Everything else falls back to the 16-bit constant.
    return Impl::B16_C0SIZE;
}
/*!
 * \brief Tells whether SrcT is one of the 32-bit element types checked here.
 * \tparam SrcT element type
 * \return true when SrcT is int32_t or float, false otherwise
 */
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportB32()
{
    return IsTypeOneOfV<SrcT, int32_t, float>;
}
/*!
 * \brief Tells whether SrcT is one of the 8-bit element types checked here.
 * \tparam SrcT element type
 * \return true for int8_t; when __NPU_ARCH__ == 3510 also for hifloat8_t, fp8_e4m3fn_t and fp8_e5m2_t
 */
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportB8()
{
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
    return IsTypeOneOfV<SrcT, int8_t, hifloat8_t, fp8_e4m3fn_t, fp8_e5m2_t>;
#else
    return IsSameTypeV<SrcT, int8_t>;
#endif
}
/*!
 * \brief Tells whether SrcT is one of the 4-bit element types checked here.
 * \tparam SrcT element type
 * \return true for int4b_t; when __NPU_ARCH__ == 3510 also for fp4x2_e1m2_t and fp4x2_e2m1_t
 */
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportB4()
{
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
    return IsTypeOneOfV<SrcT, int4b_t, fp4x2_e1m2_t, fp4x2_e2m1_t>;
#else
    return IsSameTypeV<SrcT, int4b_t>;
#endif
}
/*!
 * \brief Tells whether SrcT is one of the packed fp4 element types.
 * \tparam SrcT element type
 * \return true only when __NPU_ARCH__ == 3510 and SrcT is fp4x2_e1m2_t or fp4x2_e2m1_t
 */
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportMxFp4()
{
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
    return IsTypeOneOfV<SrcT, fp4x2_e1m2_t, fp4x2_e2m1_t>;
#else
    return false;
#endif
}
/*!
 * \brief Tells whether SrcT is one of the fp8 element types checked here.
 * \tparam SrcT element type
 * \return true only when __NPU_ARCH__ == 3510 and SrcT is fp8_e4m3fn_t or fp8_e5m2_t
 */
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportMxFp8()
{
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
    return IsTypeOneOfV<SrcT, fp8_e4m3fn_t, fp8_e5m2_t>;
#else
    return false;
#endif
}
/*!
 * \brief Tells whether T is classified as an 8-bit or a 4-bit element type.
 * \tparam T element type
 * \return IsSupportB8<T>() || IsSupportB4<T>()
 */
template <typename T>
__aicore__ inline constexpr static bool IsNeedC0Align()
{
    constexpr bool isEightBit = IsSupportB8<T>();
    constexpr bool isFourBit = IsSupportB4<T>();
    return isEightBit || isFourBit;
}
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
// NOTE(review): the meaning of the value 4 is not visible in this header — confirm against the users of INTRA_MODE.
constexpr uint8_t INTRA_MODE = 4;
/*!
 * \brief Checks INPUT_TYPE::scalePosition with PhyPosIsL1.
 * \tparam INPUT_TYPE input type descriptor
 * \return false when HasScalePosition<INPUT_TYPE>::value is false (scalePosition is then never named);
 *         otherwise PhyPosIsL1(INPUT_TYPE::scalePosition)
 */
template <typename INPUT_TYPE>
__aicore__ constexpr bool PhyMxScalePosIsL1()
{
    if constexpr (!HasScalePosition<INPUT_TYPE>::value) {
        return false;
    } else {
        return PhyPosIsL1(INPUT_TYPE::scalePosition);
    }
}
/*!
 * \brief Checks INPUT_TYPE::scalePosition with PhyPosIsUB.
 * \tparam INPUT_TYPE input type descriptor
 * \return false when HasScalePosition<INPUT_TYPE>::value is false (scalePosition is then never named);
 *         otherwise PhyPosIsUB(INPUT_TYPE::scalePosition)
 */
template <typename INPUT_TYPE>
__aicore__ constexpr bool PhyMxScalePosIsUB()
{
    if constexpr (!HasScalePosition<INPUT_TYPE>::value) {
        return false;
    } else {
        return PhyPosIsUB(INPUT_TYPE::scalePosition);
    }
}
/*!
 * \brief Checks INPUT_TYPE::scalePosition with PhyPosIsGM.
 * \tparam INPUT_TYPE input type descriptor
 * \return false when HasScalePosition<INPUT_TYPE>::value is false (scalePosition is then never named);
 *         otherwise PhyPosIsGM(INPUT_TYPE::scalePosition)
 */
template <typename INPUT_TYPE>
__aicore__ constexpr bool PhyMxScalePosIsGM()
{
    if constexpr (!HasScalePosition<INPUT_TYPE>::value) {
        return false;
    } else {
        return PhyPosIsGM(INPUT_TYPE::scalePosition);
    }
}
#endif
/*!
 * \brief Returns the size, in bits, attributed to one element of type T.
 * \tparam T element type
 * \return
 *   - sizeof(T) * ONE_BYTE_BIT_SIZE for arithmetic types (std::is_arithmetic);
 *   - ONE_BYTE_BIT_SIZE / 2 for AscendC::int4b_t;
 *   - when __NPU_ARCH__ == 3510: ONE_BYTE_BIT_SIZE for fp8_e8m0_t / hifloat8_t / fp8_e4m3fn_t / fp8_e5m2_t and
 *     ONE_BYTE_BIT_SIZE / 2 for fp4x2_e2m1_t / fp4x2_e1m2_t;
 *   - when __NPU_ARCH__ == 5102: ONE_BYTE_BIT_SIZE / 4 for AscendC::int2b_t;
 *   - ONE_BYTE_BIT_SIZE * 2 for every other type.
 *
 * Every branch is resolved at compile time with `if constexpr`.
 */
template <typename T>
__aicore__ constexpr int32_t GetBitSize()
{
    if constexpr (std::is_arithmetic<T>::value) {
        return sizeof(T) * ONE_BYTE_BIT_SIZE;
    }
    if constexpr (IsSameTypeV<T, AscendC::int4b_t>) {
        return ONE_BYTE_BIT_SIZE / 2;
    }
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
    if constexpr (IsTypeOneOfV<T, fp8_e8m0_t, hifloat8_t, fp8_e4m3fn_t, fp8_e5m2_t>) {
        return ONE_BYTE_BIT_SIZE;
    }
    if constexpr (IsTypeOneOfV<T, fp4x2_e2m1_t, fp4x2_e1m2_t>) {
        return ONE_BYTE_BIT_SIZE / 2;
    }
#endif
    // Guard written like the other architecture checks in this function so an undefined __NPU_ARCH__ is never
    // evaluated, and the branch is `if constexpr` like its siblings.
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 5102
    if constexpr (IsSameTypeV<T, AscendC::int2b_t>) {
        return ONE_BYTE_BIT_SIZE / 4;
    }
#endif
    // Types not matched above are treated as two bytes wide.
    return ONE_BYTE_BIT_SIZE * 2;
}
template <typename T>
constexpr bool IsScaleTransWithInlv =
(HasScalePosition<T>::value && PhyPosIsGM(T::pos) && (T::format == CubeFormat::ND) && PhyPosIsL1(T::scalePosition));
/*!
 * \brief Compares the element bit sizes of A and B when MatmulFeatureTrait reports no anti-quant support.
 * \tparam A_TYPE descriptor of matrix A (its element type is A_TYPE::T)
 * \tparam B_TYPE descriptor of matrix B (its element type is B_TYPE::T)
 * \tparam MM_CFG matmul configuration
 * \return false when MatmulFeatureTrait<MM_CFG>::IsMmadInstrSupportAntiQuant() is true; otherwise true exactly
 *         when GetBitSize of B's element type differs from GetBitSize of A's element type.
 */
template <typename A_TYPE, typename B_TYPE, const auto& MM_CFG>
__aicore__ inline constexpr bool IsL1BNeedTrans()
{
    if constexpr (Impl::Detail::MatmulFeatureTrait<MM_CFG>::IsMmadInstrSupportAntiQuant()) {
        return false;
    } else {
        return GetBitSize<typename B_TYPE::T>() != GetBitSize<typename A_TYPE::T>();
    }
}
/*!
 * \brief Returns a default-constructed object whose type is the descriptor chosen for matrix B.
 * \tparam A_TYPE descriptor of matrix A
 * \tparam B_TYPE descriptor of matrix B
 * \tparam MM_CFG matmul configuration
 * \return an object of type
 *   - B_TYPE when HasScalePosition<A_TYPE>::value is true;
 *   - when __NPU_ARCH__ == 5102: MatmulType<GM, NZ, int4b_t> for DECOMP_1bitTo4bit / DECOMP_2bitTo4bit and
 *     MatmulType<GM, NZ, int8_t> for DECOMP_4bitTo8bit;
 *   - on other architectures: A_TYPE when IsL1BNeedTrans<A_TYPE, B_TYPE, MM_CFG>() is true;
 *   - B_TYPE in every remaining case.
 *
 * Only the deduced return type differs between branches; each branch returns a default-initialized local.
 */
template <typename A_TYPE, typename B_TYPE, const auto& MM_CFG>
__aicore__ inline constexpr auto GetTransBDataType()
{
    if constexpr (HasScalePosition<A_TYPE>::value) {
        // A carries a scale position: B keeps its own descriptor.
        B_TYPE scaledB;
        return scaledB;
    }
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 5102
    else if constexpr (
        DecompMode(MM_CFG) == DecompressionMode::DECOMP_1bitTo4bit ||
        DecompMode(MM_CFG) == DecompressionMode::DECOMP_2bitTo4bit) {
        MatmulType<TPosition::GM, CubeFormat::NZ, int4b_t> int4B;
        return int4B;
    } else if constexpr (DecompMode(MM_CFG) == DecompressionMode::DECOMP_4bitTo8bit) {
        MatmulType<TPosition::GM, CubeFormat::NZ, int8_t> int8B;
        return int8B;
    }
#else
    else if constexpr (IsL1BNeedTrans<A_TYPE, B_TYPE, MM_CFG>()) {
        // B is described with A's descriptor in this case.
        A_TYPE sameAsA;
        return sameAsA;
    }
#endif
    else {
        B_TYPE plainB;
        return plainB;
    }
}
/*!
 * \brief Tells whether INPUT_TYPE is tagged as a scale input.
 * \tparam INPUT_TYPE input type descriptor exposing a static TAG
 * \return true when INPUT_TYPE::TAG is InputTypeTag::scaleA or InputTypeTag::scaleB
 */
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool IsScaleTag()
{
    constexpr bool isScaleA = (INPUT_TYPE::TAG == InputTypeTag::scaleA);
    constexpr bool isScaleB = (INPUT_TYPE::TAG == InputTypeTag::scaleB);
    return isScaleA || isScaleB;
}
/*!
 * \brief Applies PhyPosIsGM to the position that is relevant for INPUT_TYPE.
 * \tparam INPUT_TYPE input type descriptor
 * \return PhyPosIsGM(INPUT_TYPE::scalePosition) for scale-tagged inputs (see IsScaleTag),
 *         PhyPosIsGM(INPUT_TYPE::pos) otherwise
 */
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool InputPhyPosIsGM()
{
    // `if constexpr` is required: scalePosition is only named for scale-tagged inputs.
    if constexpr (!IsScaleTag<INPUT_TYPE>()) {
        return PhyPosIsGM(INPUT_TYPE::pos);
    } else {
        return PhyPosIsGM(INPUT_TYPE::scalePosition);
    }
}
/*!
 * \brief Applies PhyPosIsL1 to the position that is relevant for INPUT_TYPE.
 * \tparam INPUT_TYPE input type descriptor
 * \return PhyPosIsL1(INPUT_TYPE::scalePosition) for scale-tagged inputs (see IsScaleTag),
 *         PhyPosIsL1(INPUT_TYPE::pos) otherwise
 */
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool InputPhyPosIsL1()
{
    // `if constexpr` is required: scalePosition is only named for scale-tagged inputs.
    if constexpr (!IsScaleTag<INPUT_TYPE>()) {
        return PhyPosIsL1(INPUT_TYPE::pos);
    } else {
        return PhyPosIsL1(INPUT_TYPE::scalePosition);
    }
}
/*!
 * \brief Applies PhyPosIsUB to the position that is relevant for INPUT_TYPE.
 * \tparam INPUT_TYPE input type descriptor
 * \return PhyPosIsUB(INPUT_TYPE::scalePosition) for scale-tagged inputs (see IsScaleTag),
 *         PhyPosIsUB(INPUT_TYPE::pos) otherwise
 */
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool InputPhyPosIsUB()
{
    // `if constexpr` is required: scalePosition is only named for scale-tagged inputs.
    if constexpr (!IsScaleTag<INPUT_TYPE>()) {
        return PhyPosIsUB(INPUT_TYPE::pos);
    } else {
        return PhyPosIsUB(INPUT_TYPE::scalePosition);
    }
}
/*!
 * \brief True when __NPU_ARCH__ == 3510 and T is fp8_e4m3fn_t, fp8_e5m2_t, fp4x2_e2m1_t or fp4x2_e1m2_t;
 *        false for every type on any other architecture.
 * \tparam T element type
 */
template <typename T>
constexpr bool SupportMXFP8 =
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
    IsTypeOneOfV<T, fp8_e4m3fn_t, fp8_e5m2_t, fp4x2_e2m1_t, fp4x2_e1m2_t>;
#else
    false;
#endif
template <typename AType, typename BType, const auto& MM_CFG>
constexpr bool IsMxDisableUnitFlag =
(EnUnitFlag(MM_CFG) && HasScalePosition<AType>::value &&
(AType::isTrans || !BType::isTrans || IsStaticPaddingEnable(MM_CFG)) && SupportMXFP8<typename AType::T>);
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_UTILS_MX_MATMUL_UTILS_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_UTILS_MX_MATMUL_UTILS_H__
#endif