* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file matmul_impl.h
* \brief
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/matmul/matmul_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/matmul/matmul.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_MATMUL_IMPL_H__
#endif
#ifndef IMPL_MATMUL_MATMUL_IMPL_H
#define IMPL_MATMUL_MATMUL_IMPL_H
#include "matmul_impl_base.h"
namespace AscendC {
template <
class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG, class MM_CB,
MATMUL_POLICY_TEMPLATE_OF(MATMUL_POLICY)>
class MatmulImpl<
A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY,
enable_if_t<A_TYPE::layout == LayoutMode::NONE && !isMxMatmul<A_TYPE, B_TYPE>>>
: public MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY> {
private:
using SrcT = typename A_TYPE::T;
using DstT = typename C_TYPE::T;
using IMPL = MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;
MATMUL_USE_MODULE(Scheduler);
MATMUL_USE_MODULE(MatmulShapeTiling);
MATMUL_USE_MODULE(MatmulAntiQuantProcessor);
MATMUL_USE_MODULE(MatmulSubBlockInfo);
MATMUL_USE_MODULE(QtableProcessor);
public:
using BASE_MODULE = MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;
__aicore__ inline MatmulImpl() {}
__aicore__ inline void SetAntiQuantScalar(const SrcT offsetScalar, const SrcT scaleScalar)
{
MATMUL_MODULE(MatmulAntiQuantProcessor)->SetAntiQuantScalar(offsetScalar, scaleScalar);
}
__aicore__ inline void SetAntiQuantVector(
const LocalTensor<SrcT>& offsetTensor, const LocalTensor<SrcT>& scaleTensor)
{
MATMUL_MODULE(MatmulAntiQuantProcessor)->SetAntiQuantVector(offsetTensor, scaleTensor);
}
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 1001 || __NPU_ARCH__ == 2002)
template <bool sync = true>
__aicore__ inline void IterateAll(
const GlobalTensor<DstT>& gm, uint8_t enAtomic = 0, bool enSequentialWrite = false, bool waitIterateAll = false,
bool fakeMsg = false)
{
ASCENDC_ASSERT((!ToMatmulConfig(MM_CFG).isPartialOutput), {
KERNEL_LOG(KERNEL_ERROR, "IterateAll is not supported for PartialOutput.");
});
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 2002
GlobalTensor<uint64_t> global;
global.SetGlobalBuffer((__gm__ uint64_t*)0);
DataCacheCleanAndInvalid<uint64_t, CacheLine::ENTIRE_DATA_CACHE>(global);
#endif
const auto tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
while (BASE_MODULE::Iterate()) {
BASE_MODULE::GetTensorC(gm, enAtomic);
if constexpr (ToMatmulConfig(MM_CFG).enableUBReuse && !ToMatmulConfig(MM_CFG).enableL1CacheUB) {
event_t eventIDMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
SetFlag<HardEvent::MTE3_MTE2>(eventIDMte3ToMte2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMte3ToMte2);
} else if constexpr (ToMatmulConfig(MM_CFG).enableL1CacheUB) {
if ((tiling.GetDepthAL1CacheUB() == 0 && A_TYPE::format == CubeFormat::ND) ||
(tiling.GetDepthBL1CacheUB() == 0 && B_TYPE::format == CubeFormat::ND)) {
event_t eventIDMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
SetFlag<HardEvent::MTE3_MTE2>(eventIDMte3ToMte2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMte3ToMte2);
}
}
}
}
template <bool sync = true>
__aicore__ inline void IterateAll(const LocalTensor<DstT>& ubCmatrix, uint8_t enAtomic = 0)
{
ASCENDC_ASSERT((!ToMatmulConfig(MM_CFG).isPartialOutput), {
KERNEL_LOG(KERNEL_ERROR, "IterateAll is not supported for PartialOutput.");
});
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 2002
GlobalTensor<uint64_t> global;
global.SetGlobalBuffer((__gm__ uint64_t*)0);
DataCacheCleanAndInvalid<uint64_t, CacheLine::ENTIRE_DATA_CACHE>(global);
#endif
(void)(enAtomic);
while (BASE_MODULE::Iterate()) {
BASE_MODULE::GetTensorC(ubCmatrix);
event_t eventIDVToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventIDVToMte2);
WaitFlag<HardEvent::V_MTE2>(eventIDVToMte2);
}
}
#else
template <bool sync = true>
__aicore__ inline void IterateAll(
const GlobalTensor<DstT>& gm, uint8_t enAtomic = 0, bool enSequentialWrite = false, bool waitIterateAll = false,
bool fakeMsg = false)
{
ASCENDC_ASSERT((!ToMatmulConfig(MM_CFG).isPartialOutput), {
KERNEL_LOG(KERNEL_ERROR, "IterateAll is not supported for PartialOutput.");
});
if constexpr (BASE_MODULE::POLICY::POLICY_TYPE == PolicyType::MATMUL_NBUFFER_33) {
static_assert(DoMatmulMDL(MM_CFG), "NBuffer33MatmulPolicy only support MDL config.");
MATMUL_MODULE(Scheduler)->Schedule(gm, enAtomic, enSequentialWrite);
} else {
if constexpr (ToMatmulConfig(MM_CFG).intraBlockPartSum) {
MATMUL_MODULE(MatmulSubBlockInfo)->SetFakeMsg(fakeMsg);
MATMUL_MODULE(Scheduler)->Schedule(gm, enAtomic, enSequentialWrite, fakeMsg);
} else {
while (BASE_MODULE::Iterate()) {
BASE_MODULE::GetTensorC(gm, enAtomic);
}
}
}
}
template <bool sync = true>
__aicore__ inline void IterateAll(const LocalTensor<DstT>& ubCmatrix, uint8_t enAtomic = 0)
{
ASCENDC_ASSERT((!ToMatmulConfig(MM_CFG).isPartialOutput), {
KERNEL_LOG(KERNEL_ERROR, "IterateAll is not supported for PartialOutput.");
});
if constexpr (BASE_MODULE::POLICY::POLICY_TYPE == PolicyType::MATMUL_NBUFFER_33) {
static_assert(DoMatmulMDL(MM_CFG), "NBuffer33MatmulPolicy only support MDL config.");
MATMUL_MODULE(Scheduler)->Schedule(ubCmatrix, enAtomic);
} else {
if constexpr (PhyPosIsL0C(C_TYPE::pos)) {
int64_t dstOffset = 0;
while (BASE_MODULE::Iterate(false, ubCmatrix[dstOffset])) {
dstOffset += MATMUL_MODULE(Scheduler)->GetL0cOffset();
BASE_MODULE::GetTensorC(ubCmatrix, enAtomic);
}
} else {
while (BASE_MODULE::Iterate(false)) {
BASE_MODULE::GetTensorC(ubCmatrix, enAtomic);
}
}
}
}
#if __NPU_ARCH__ == 5102
__aicore__ inline void WaitIterateAll()
{
event_t eventIDFixToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::FIX_MTE2));
SetFlag<HardEvent::FIX_MTE2>(eventIDFixToMte2);
WaitFlag<HardEvent::FIX_MTE2>(eventIDFixToMte2);
}
__aicore__ inline void SetLookupTable(const GlobalTensor<uint64_t>& qtableTensor)
{
MATMUL_MODULE(QtableProcessor)->SetLookupTable(qtableTensor);
}
#endif
#endif
};
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_MATMUL_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_MATMUL_IMPL_H__
#endif