/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
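/*!
* \file mx_matmul_impl.h
* \brief MatmulImpl partial specialization used when isMxMatmul<A_TYPE, B_TYPE> holds and
*        A_TYPE::layout is LayoutMode::NONE; adds the scale-A / scale-B modules and setters.
*/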
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/matmul/mx_matmul_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/matmul/matmul.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_MX_MATMUL_IMPL_H__
#endif
#ifndef IMPL_MATMUL_MX_MATMUL_IMPL_H
#define IMPL_MATMUL_MX_MATMUL_IMPL_H
#include "matmul_impl_base.h"
namespace AscendC {
/**
 * \brief MatmulImpl partial specialization selected when A_TYPE::layout is LayoutMode::NONE and
 *        isMxMatmul<A_TYPE, B_TYPE> holds.
 *
 * Derives from MatmulImplBase and additionally imports the scale-A / scale-B flavors of the
 * copy-in, cube-in buffer, tensor-info, copy-params, data-copy-utils and data-copy-wrapper
 * modules. On top of the base interface it provides setters for the scale tensors (element type
 * fp8_e8m0_t), its own IterateAll overloads, a GetMatrixL1Addr that can also report the scale
 * buffer addresses, and setters for the intra scale ids used by MatmulCrossCoreSync.
 */
template <
    class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG, class MM_CB,
    MATMUL_POLICY_TEMPLATE_OF(MATMUL_POLICY)>
class MatmulImpl<
    A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY,
    enable_if_t<A_TYPE::layout == LayoutMode::NONE && isMxMatmul<A_TYPE, B_TYPE>>>
    : public MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>,
      // Scale-specific modules that exist only in this specialization.
      MATMUL_IMPORT_MODULE(CopyCubeInScaleA),
      MATMUL_IMPORT_MODULE(CopyCubeInScaleB),
      MATMUL_IMPORT_MODULE(CubeInBufferScaleA),
      MATMUL_IMPORT_MODULE(CubeInBufferScaleB),
      MATMUL_IMPORT_MODULE_PRIVATE(MatmulTensorInfoScaleA),
      MATMUL_IMPORT_MODULE_PRIVATE(MatmulTensorInfoScaleB),
      MATMUL_IMPORT_MODULE_PRIVATE(CopyCubeInParamsScaleA),
      MATMUL_IMPORT_MODULE_PRIVATE(CopyCubeInParamsScaleB),
      MATMUL_IMPORT_MODULE_PRIVATE(DataCopyUtilsScaleA),
      MATMUL_IMPORT_MODULE_PRIVATE(DataCopyUtilsScaleB),
      MATMUL_IMPORT_MODULE_PRIVATE(DataCopyWrapperScaleA),
      MATMUL_IMPORT_MODULE_PRIVATE(DataCopyWrapperScaleB) {
private:
    // Element type of the scale tensors accepted by SetTensorScaleA / SetTensorScaleB.
    using SrcScaleT = fp8_e8m0_t;
    // Element type of the output tensor accepted by IterateAll.
    using DstT = typename C_TYPE::T;
    using IMPL = MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;

public:
    MATMUL_ALLOW_USING(CubeInBufferA);
    MATMUL_ALLOW_USING(CubeInBufferB);
    MATMUL_ALLOW_USING(CopyCubeInScaleA);
    MATMUL_ALLOW_USING(CopyCubeInScaleB);
    MATMUL_ALLOW_USING(CubeInBufferScaleA);
    MATMUL_ALLOW_USING(CubeInBufferScaleB);
    MATMUL_ALLOW_USING(Scheduler);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoA);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoB);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoScaleA);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoScaleB);
    MATMUL_ALLOW_USING_PRIVATE(CopyCubeInParamsA);
    MATMUL_ALLOW_USING_PRIVATE(CopyCubeInParamsB);
    MATMUL_ALLOW_USING_PRIVATE(CopyCubeInParamsScaleA);
    MATMUL_ALLOW_USING_PRIVATE(CopyCubeInParamsScaleB);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyUtilsA);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyUtilsB);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyUtilsScaleA);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyUtilsScaleB);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyWrapperA);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyWrapperB);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyWrapperScaleA);
    MATMUL_ALLOW_USING_PRIVATE(DataCopyWrapperScaleB);

private:
    // Modules reached through MATMUL_MODULE(...) in the member functions below.
    MATMUL_USE_MODULE(CubeInBufferA);
    MATMUL_USE_MODULE(CubeInBufferB);
    MATMUL_USE_MODULE(CopyCubeInScaleA);
    MATMUL_USE_MODULE(CopyCubeInScaleB);
    MATMUL_USE_MODULE(CubeInBufferScaleA);
    MATMUL_USE_MODULE(CubeInBufferScaleB);
    MATMUL_USE_MODULE(Scheduler);
    MATMUL_USE_MODULE(MatmulCrossCoreSync);

public:
    // Tag-indexed module aliases. In each alias the first condition is true for the non-scale
    // tags (A, B) and the second is true for the A-side tags (A, scaleA); the four candidate
    // types are listed in the order A, B, scaleA, scaleB.
    // NOTE(review): ConditionalMulti is defined elsewhere; presumably it picks the candidate
    // matching the two conditions - confirm against its definition.
    template <InputTypeTag TAG>
    using CubeInBuffer = typename ConditionalMulti<
        TAG == InputTypeTag::A || TAG == InputTypeTag::B, TAG == InputTypeTag::A || TAG == InputTypeTag::scaleA,
        CubeInBufferA, CubeInBufferB, CubeInBufferScaleA, CubeInBufferScaleB>::type;
    template <InputTypeTag TAG>
    using CopyCubeInParams = typename ConditionalMulti<
        TAG == InputTypeTag::A || TAG == InputTypeTag::B, TAG == InputTypeTag::A || TAG == InputTypeTag::scaleA,
        CopyCubeInParamsA, CopyCubeInParamsB, CopyCubeInParamsScaleA, CopyCubeInParamsScaleB>::type;
    template <InputTypeTag TAG>
    using MatmulTensorInfo = typename ConditionalMulti<
        TAG == InputTypeTag::A || TAG == InputTypeTag::B, TAG == InputTypeTag::A || TAG == InputTypeTag::scaleA,
        MatmulTensorInfoA, MatmulTensorInfoB, MatmulTensorInfoScaleA, MatmulTensorInfoScaleB>::type;
    template <InputTypeTag TAG>
    using DataCopyUtils = typename ConditionalMulti<
        TAG == InputTypeTag::A || TAG == InputTypeTag::B, TAG == InputTypeTag::A || TAG == InputTypeTag::scaleA,
        DataCopyUtilsA, DataCopyUtilsB, DataCopyUtilsScaleA, DataCopyUtilsScaleB>::type;
    template <InputTypeTag TAG>
    using DataCopyWrapper = typename ConditionalMulti<
        TAG == InputTypeTag::A || TAG == InputTypeTag::B, TAG == InputTypeTag::A || TAG == InputTypeTag::scaleA,
        DataCopyWrapperA, DataCopyWrapperB, DataCopyWrapperScaleA, DataCopyWrapperScaleB>::type;

    using BASE_MODULE = MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;

    __aicore__ inline MatmulImpl() {}

    /**
     * \brief Validate the configured positions at compile time, then run the base Init.
     *
     * For each of A, A-scale, B and B-scale: if the operand's position is physically L1, its
     * source position must be physically GM or UB; otherwise compilation fails.
     *
     * \param cubeTiling tiling data forwarded unchanged to MatmulImplBase::Init
     * \param tpipe      pipe pointer forwarded unchanged to MatmulImplBase::Init (default nullptr)
     */
    __aicore__ inline void Init(const TCubeTiling* __restrict cubeTiling, TPipe* tpipe = nullptr)
    {
        static_assert(
            !PhyPosIsL1(A_TYPE::pos) || (PhyPosIsGM(A_TYPE::srcPos) || PhyPosIsUB(A_TYPE::srcPos)),
            "A_TYPE::srcPos only support GM or VECOUT when A_TYPE::pos is TSCM.");
        static_assert(
            !PhyPosIsL1(A_TYPE::scalePosition) || (PhyPosIsGM(A_TYPE::srcScalePos) || PhyPosIsUB(A_TYPE::srcScalePos)),
            "A_TYPE::srcScalePos only support GM or VECOUT when A_TYPE::scalePosition is TSCM.");
        static_assert(
            !PhyPosIsL1(B_TYPE::pos) || (PhyPosIsGM(B_TYPE::srcPos) || PhyPosIsUB(B_TYPE::srcPos)),
            "B_TYPE::srcPos only support GM or VECOUT when B_TYPE::pos is TSCM.");
        static_assert(
            !PhyPosIsL1(B_TYPE::scalePosition) || (PhyPosIsGM(B_TYPE::srcScalePos) || PhyPosIsUB(B_TYPE::srcScalePos)),
            "B_TYPE::srcScalePos only support GM or VECOUT when B_TYPE::scalePosition is TSCM.");
        BASE_MODULE::Init(cubeTiling, tpipe);
    }

    /**
     * \brief Set the scale tensor for A from global memory and reset the Scheduler.
     * \param gm                scale tensor passed to CopyCubeInScaleA::SetInput
     * \param isTransposeScaleA transpose flag passed alongside it (default false)
     */
    __aicore__ inline void SetTensorScaleA(const GlobalTensor<SrcScaleT>& gm, bool isTransposeScaleA = false)
    {
        MATMUL_MODULE(CopyCubeInScaleA)->SetInput(gm, isTransposeScaleA);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    /**
     * \brief Set the scale tensor for A from a local tensor and reset the Scheduler.
     * \param leftMatrix        scale tensor passed to CopyCubeInScaleA::SetInput
     * \param isTransposeScaleA transpose flag passed alongside it (default false)
     */
    __aicore__ inline void SetTensorScaleA(const LocalTensor<SrcScaleT>& leftMatrix, bool isTransposeScaleA = false)
    {
        MATMUL_MODULE(CopyCubeInScaleA)->SetInput(leftMatrix, isTransposeScaleA);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    /**
     * \brief Set the scale tensor for B from global memory and reset the Scheduler.
     * \param gm                scale tensor passed to CopyCubeInScaleB::SetInput
     * \param isTransposeScaleB transpose flag passed alongside it (default true, unlike scale A)
     */
    __aicore__ inline void SetTensorScaleB(const GlobalTensor<SrcScaleT>& gm, bool isTransposeScaleB = true)
    {
        MATMUL_MODULE(CopyCubeInScaleB)->SetInput(gm, isTransposeScaleB);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    /**
     * \brief Set the scale tensor for B from a local tensor and reset the Scheduler.
     * \param rightMatrix       scale tensor passed to CopyCubeInScaleB::SetInput
     * \param isTransposeScaleB transpose flag passed alongside it (default true, unlike scale A)
     */
    __aicore__ inline void SetTensorScaleB(const LocalTensor<SrcScaleT>& rightMatrix, bool isTransposeScaleB = true)
    {
        MATMUL_MODULE(CopyCubeInScaleB)->SetInput(rightMatrix, isTransposeScaleB);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    /**
     * \brief Run Iterate until it reports no more work, writing each result to global memory.
     *
     * Asserts that the config is not PartialOutput.
     *
     * \tparam sync             not read by this implementation
     * \param gm                destination passed to GetTensorC after every successful Iterate
     * \param enAtomic          atomic mode forwarded to GetTensorC (default 0)
     * \param enSequentialWrite sequential-write flag forwarded to GetTensorC (default false)
     * \param waitIterateAll    not read by this implementation
     * \param fakeMsg           not read by this implementation
     */
    template <bool sync = true>
    __aicore__ inline void IterateAll(
        const GlobalTensor<DstT>& gm, uint8_t enAtomic = 0, bool enSequentialWrite = false, bool waitIterateAll = false,
        bool fakeMsg = false)
    {
        ASCENDC_ASSERT((!ToMatmulConfig(MM_CFG).isPartialOutput), {
            KERNEL_LOG(KERNEL_ERROR, "IterateAll is not supported for PartialOutput.");
        });
        while (BASE_MODULE::Iterate()) {
            // Forward enSequentialWrite as well; previously the caller's request was dropped here.
            BASE_MODULE::GetTensorC(gm, enAtomic, enSequentialWrite);
        }
    }

    /**
     * \brief Run Iterate until it reports no more work, with a local tensor as the destination.
     *
     * Asserts that the config is not PartialOutput. Each Iterate call receives ubCmatrix at the
     * current offset; when C_TYPE::pos is physically L0C the offset advances by the Scheduler's
     * GetL0cOffset() after every successful iteration, otherwise it stays 0.
     *
     * \tparam sync     not read by this implementation
     * \param ubCmatrix destination local tensor
     * \param enAtomic  atomic mode forwarded to GetTensorC (default 0)
     */
    template <bool sync = true>
    __aicore__ inline void IterateAll(const LocalTensor<DstT>& ubCmatrix, uint8_t enAtomic = 0)
    {
        ASCENDC_ASSERT((!ToMatmulConfig(MM_CFG).isPartialOutput), {
            KERNEL_LOG(KERNEL_ERROR, "IterateAll is not supported for PartialOutput.");
        });
        int64_t dstOffset = 0;
        while (BASE_MODULE::Iterate(false, ubCmatrix[dstOffset])) {
            if constexpr (PhyPosIsL0C(C_TYPE::pos)) {
                dstOffset += MATMUL_MODULE(Scheduler)->GetL0cOffset();
            }
            // NOTE(review): GetTensorC also runs on the L0C path, where Iterate was already handed
            // the destination slice - confirm whether this call belongs in an else branch.
            BASE_MODULE::GetTensorC(ubCmatrix, enAtomic);
        }
    }

    /**
     * \brief Return the base MatrixL1Addr, extended with the scale buffer addresses where applicable.
     *
     * Only when __NPU_ARCH__ == 3510: l1aScaleAddr / l1bScaleAddr are overwritten with the head
     * address of CubeInBufferScaleA / CubeInBufferScaleB if PhyMxScalePosIsUB<A_TYPE>() /
     * PhyMxScalePosIsUB<B_TYPE>() is true. On other architectures the base result is returned as is.
     */
    __aicore__ inline MatrixL1Addr GetMatrixL1Addr()
    {
        MatrixL1Addr matrixL1Addr = BASE_MODULE::GetMatrixL1Addr();
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3510
        if constexpr (PhyMxScalePosIsUB<A_TYPE>()) {
            matrixL1Addr.l1aScaleAddr = MATMUL_MODULE(CubeInBufferScaleA)->GetBufferHeadAddr();
        }
        if constexpr (PhyMxScalePosIsUB<B_TYPE>()) {
            matrixL1Addr.l1bScaleAddr = MATMUL_MODULE(CubeInBufferScaleB)->GetBufferHeadAddr();
        }
#endif
        return matrixL1Addr;
    }

    /** \brief Forward the intra id for scale A to MatmulCrossCoreSync. */
    __aicore__ inline void SetIntraScaleAId(uint8_t intraId)
    {
        MATMUL_MODULE(MatmulCrossCoreSync)->SetIntraScaleAId(intraId);
    }

    /** \brief Forward the intra id for scale B to MatmulCrossCoreSync. */
    __aicore__ inline void SetIntraScaleBId(uint8_t intraId)
    {
        MATMUL_MODULE(MatmulCrossCoreSync)->SetIntraScaleBId(intraId);
    }

    /** \brief Free-function entry point that calls mm.SetIntraScaleAId(intraId). */
    friend __aicore__ inline void KfcSetIntraScaleAId(
        MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>& mm, uint8_t intraId)
    {
        mm.SetIntraScaleAId(intraId);
    }

    /** \brief Free-function entry point that calls mm.SetIntraScaleBId(intraId). */
    friend __aicore__ inline void KfcSetIntraScaleBId(
        MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>& mm, uint8_t intraId)
    {
        mm.SetIntraScaleBId(intraId);
    }
};
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_MX_MATMUL_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_MX_MATMUL_IMPL_H__
#endif