/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
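/*!
* \file conv3d_api_impl.h
* \brief Declares Conv3dApiImpl, the Conv3D API implementation struct, and its nested ContextData state.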
*/
#ifndef API_CONV3D_API_IMPL_H
#define API_CONV3D_API_IMPL_H
#include "kernel_basic_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "kernel_utils.h"
#include "conv3d_util.h"
#include "include/adv_api/conv/conv3d/conv3d_config.h"
#include "conv3d_common_func.h"
#include "../common/conv_forward_framework_util.h"
#include "../common/conv_forward_common_func.h"
// Queue depth passed to the A1 and B1 TQue members (queueAL1 / queueBL1) of Conv3dApiImpl::ContextData.
#define QUE_DEPTH_SIZE_TWO 2
namespace Conv3dApi {
/**
 * Conv3D API implementation struct.
 *
 * Exposes compile-time options through CONV_DEFINE_MEMBER, registers the API entry points through
 * CONV_REG_IMPL, and declares ContextData, the data members used by the implementation.
 *
 * @tparam Intf        Interface type forwarded to every Conv3dApiFunc::*Tools helper template.
 * @tparam Config_     Configuration type. Must provide ConvParam, implType, formatA, SrcAT, SrcBT,
 *                     BiasT, L0cT and a ContextData base struct.
 * @tparam OutputOrder Not referenced inside this struct; defaults to true.
 */
template <typename Intf, class Config_, bool OutputOrder = true>
struct Conv3dApiImpl {
public:
    using Config = Config_;
    using ConvParam = typename Config_::ConvParam;
    constexpr static uint32_t ImplType = Config::implType;

    // NOTE(review): CONV_DEFINE_MEMBER is defined elsewhere. Its arguments appear to be
    // (namespace, parameter type, member name, default value, value type) - confirm against the macro.
    // The names it introduces (e.g. groupConvType) are used below as compile-time selectors.
    CONV_DEFINE_MEMBER(ConvApi, ConvParam, outputOrder, true, bool);
    CONV_DEFINE_MEMBER(ConvApi, ConvParam, l0pingpong, static_cast<int8_t>(ConvL0PingPong::ALL_CLOSE), int);
    CONV_DEFINE_MEMBER(ConvApi, ConvParam, bl1bypass, static_cast<int8_t>(ConvBL1ByPass::BYPASS_ON), int);
    CONV_DEFINE_MEMBER(ConvApi, ConvParam, groupConvType, static_cast<int8_t>(GroupConvType::NoGroup_Conv), int);

public:
    __aicore__ inline Conv3dApiImpl() {}

    // API entry points. Each one is registered from either the Conv3dApiFunc or the ConvApiFunc namespace.
    // NOTE(review): CONV_REG_IMPL is defined elsewhere; its expansion is not visible here.
    CONV_REG_IMPL(Config, Conv3dApiFunc, Init);
    CONV_REG_IMPL(Config, ConvApiFunc, SetInput);
    CONV_REG_IMPL(Config, ConvApiFunc, SetWeight);
    CONV_REG_IMPL(Config, ConvApiFunc, SetBias);
    CONV_REG_IMPL(Config, Conv3dApiFunc, SetOrgInputShape);
    CONV_REG_IMPL(Config, Conv3dApiFunc, SetOrgWeightShape);
    CONV_REG_IMPL(Config, Conv3dApiFunc, SetOrgOutputShape);
    CONV_REG_IMPL(Config, Conv3dApiFunc, SetSingleInputShape);
    CONV_REG_IMPL(Config, Conv3dApiFunc, SetSingleOutputShape);
    CONV_REG_IMPL(Config, Conv3dApiFunc, SetInputStartPosition);
    CONV_REG_IMPL(Config, Conv3dApiFunc, SetGroupOptInfo);
    CONV_REG_IMPL(Config, Conv3dApiFunc, Iterate);
    CONV_REG_IMPL(Config, Conv3dApiFunc, IterateAll);
    CONV_REG_IMPL(Config, Conv3dApiFunc, GetTensorC);
    CONV_REG_IMPL(Config, ConvApiFunc, End);

    /**
     * Data members of the Conv3D implementation, layered on top of Config::ContextData.
     *
     * Member order is kept exactly as declared originally because it determines object layout and
     * initialisation order.
     */
    struct ContextData : public Config::ContextData {
        __aicore__ inline ContextData() {}

        // Tiling parameters. Initialised to nullptr (like `pipe`) so the pointer never holds an
        // indeterminate value; it is expected to be assigned elsewhere (not visible in this file).
        const struct TConv3DApiTiling *__restrict conv3dTiling = nullptr;
        AscendC::TPipe* pipe = nullptr;

        // ---- Helper ("tools") type selection -------------------------------------------------
        // Every alias picks the *WithPointWise* variant (or the broadcast variant for the L0 bias
        // loader) when Config::formatA is ConvFormat::NCDHW, and the regular variant otherwise.
        using LoadAL1Tools = typename AscendC::Conditional<Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::LoadAL1WithPointWiseTools<Intf>, Conv3dApiFunc::LoadAL1Tools<Intf>>::type;
        using LoadBL1ToolsTmp = typename AscendC::Conditional<Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::LoadBL1WithPointWiseTools<Intf>, Conv3dApiFunc::LoadBL1Tools<Intf>>::type;
        // When groupConvType is non-zero, the group-opt loader replaces the format-based choice above.
        using LoadBL1Tools =
            typename AscendC::Conditional<groupConvType, Conv3dApiFunc::LoadBL1WithGroupOptTools<Intf>, LoadBL1ToolsTmp>::type;
        using MMadTools = typename AscendC::Conditional<Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::MMadWithPointWiseTools<Intf>, Conv3dApiFunc::MMadTools<Intf>>::type;
        using LoadBL0ToolsTmp = typename AscendC::Conditional<Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::LoadBL0WithPointWiseTools<Intf>, Conv3dApiFunc::LoadBL0Tools<Intf>>::type;
        // Same group-opt override as for LoadBL1Tools.
        using LoadBL0Tools =
            typename AscendC::Conditional<groupConvType, Conv3dApiFunc::LoadBL0WithGroupOptTools<Intf>, LoadBL0ToolsTmp>::type;
        using LoadAL0Tools = typename AscendC::Conditional<
            Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::LoadAL0WithPointWiseTools<Intf>, Conv3dApiFunc::LoadAL0Tools<Intf>>::type;
        using LoadChannelWiseL1Tools = typename AscendC::Conditional<
            Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::LoadBiasL1WithPointWiseTools<Intf, typename Config::BiasT>,
            Conv3dApiFunc::LoadChannelWiseL1Tools<Intf, typename Config::BiasT>>::type;
        using LoadBiasL0Tools = typename AscendC::Conditional<
            Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::LoadBiasL0WithBroadcastTools<Intf>, Conv3dApiFunc::LoadBiasBtTools<Intf>>::type;
        using CopyOutTools = typename AscendC::Conditional<
            Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            Conv3dApiFunc::CopyOutWithPointWiseTools<Intf>, Conv3dApiFunc::CopyOutTools<Intf>>::type;

        // One instance of each selected helper type.
        LoadAL1Tools loadAl1Ins;
        LoadBL1Tools loadBL1Ins;
        MMadTools madIns;
        LoadBL0Tools loadBL0Ins;
        LoadAL0Tools loadAL0Ins;
        LoadChannelWiseL1Tools loadBiasL1Ins;
        LoadBiasL0Tools loadBiasBTIns;
        CopyOutTools copyOutIns;

        // ---- Tensor handles ------------------------------------------------------------------
        // GlobalTensor handles typed with SrcAT, SrcBT and BiasT respectively.
        AscendC::GlobalTensor<typename Config::SrcAT> agm;
        AscendC::GlobalTensor<typename Config::SrcBT> bgm;
        AscendC::GlobalTensor<typename Config::BiasT> biasgm;

        // LocalTensor handles, each constructed with its TPosition followed by two zero arguments.
        AscendC::LocalTensor<typename Config::SrcAT> al1 =
            AscendC::LocalTensor<typename Config::SrcAT>(AscendC::TPosition::A1, 0, 0);
        AscendC::LocalTensor<typename Config::SrcBT> bl1 =
            AscendC::LocalTensor<typename Config::SrcBT>(AscendC::TPosition::B1, 0, 0);
        AscendC::LocalTensor<typename Config::BiasT> biasL1 =
            AscendC::LocalTensor<typename Config::BiasT>(AscendC::TPosition::A1, 0, 0);
        AscendC::LocalTensor<typename Config::L0cT> biasBT =
            AscendC::LocalTensor<typename Config::L0cT>(AscendC::TPosition::C2, 0, 0);
        AscendC::LocalTensor<typename Config::SrcAT> al0 =
            AscendC::LocalTensor<typename Config::SrcAT>(AscendC::TPosition::A2, 0, 0);
        AscendC::LocalTensor<typename Config::SrcBT> bl0 =
            AscendC::LocalTensor<typename Config::SrcBT>(AscendC::TPosition::B2, 0, 0);
        // Ping/pong pairs for the A2 and B2 positions.
        // NOTE(review): presumably used when l0pingpong enables double buffering - confirm.
        AscendC::LocalTensor<typename Config::SrcAT> al0Ping =
            AscendC::LocalTensor<typename Config::SrcAT>(AscendC::TPosition::A2, 0, 0);
        AscendC::LocalTensor<typename Config::SrcAT> al0Pong =
            AscendC::LocalTensor<typename Config::SrcAT>(AscendC::TPosition::A2, 0, 0);
        AscendC::LocalTensor<typename Config::SrcBT> bl0Ping =
            AscendC::LocalTensor<typename Config::SrcBT>(AscendC::TPosition::B2, 0, 0);
        AscendC::LocalTensor<typename Config::SrcBT> bl0Pong =
            AscendC::LocalTensor<typename Config::SrcBT>(AscendC::TPosition::B2, 0, 0);
        AscendC::LocalTensor<typename Config::L0cT> cl0 =
            AscendC::LocalTensor<typename Config::L0cT>(AscendC::TPosition::CO1, 0, 0);

        // ---- Queues and buffers --------------------------------------------------------------
        // The A1 and B1 queues use depth QUE_DEPTH_SIZE_TWO; the bias and CO1 queues use depth 1.
        AscendC::TQue<AscendC::QuePosition::A1, QUE_DEPTH_SIZE_TWO> queueAL1;
        AscendC::TQue<AscendC::QuePosition::B1, QUE_DEPTH_SIZE_TWO> queueBL1;
        AscendC::TQue<AscendC::QuePosition::A1, 1> queueBiasL1;
        AscendC::TQue<AscendC::TPosition::C2, 1> queueBiasBT;
        AscendC::TQue<AscendC::QuePosition::CO1, 1> queueCL0;

        // For NCDHW the buffer positions are swapped: l0aBuf is a B2 buffer and l0bBuf an A2 buffer.
        // For every other format they are A2 and B2 respectively.
        using L0aBufType = typename AscendC::Conditional<
            Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            AscendC::TBuf<AscendC::TPosition::B2>, AscendC::TBuf<AscendC::TPosition::A2>>::type;
        using L0bBufType = typename AscendC::Conditional<
            Config::formatA == ConvCommonApi::ConvFormat::NCDHW,
            AscendC::TBuf<AscendC::TPosition::A2>, AscendC::TBuf<AscendC::TPosition::B2>>::type;
        L0aBufType l0aBuf;
        L0bBufType l0bBuf;

        // L0cT-typed tensor handles at the A2 / B2 positions.
        // NOTE(review): names suggest they carry bias data in the broadcast-bias path - confirm.
        AscendC::LocalTensor<typename Config::L0cT> al0BiasB =
            AscendC::LocalTensor<typename Config::L0cT>(AscendC::TPosition::A2, 0, 0);
        AscendC::LocalTensor<typename Config::L0cT> bl0BiasB =
            AscendC::LocalTensor<typename Config::L0cT>(AscendC::TPosition::B2, 0, 0);

        // ---- Boolean state, stored as uint8_t ------------------------------------------------
        uint8_t enableBias = false;
        uint8_t isFirstIterate = true;
        uint8_t loadAL1Flag = true;
        uint8_t loadBL1Flag = true;
        uint8_t loadAL0Flag = true;
        uint8_t loadBL0Flag = true;
        uint8_t kAL1fullload = false;
        uint8_t kBL1fullload = false;
        uint8_t biasFullLoadFlag = false;
        uint8_t mL0IsDivisibleByWo = false;
        uint8_t freeAL1TensorFlag = false;
        uint8_t freeBL1TensorFlag = false;
        uint8_t isGroupOptDimTail = false;

        // ---- Current iteration indices (all start at 0) --------------------------------------
        uint64_t kAL1Iter = 0;
        uint64_t kBL1Iter = 0;
        uint64_t mAL1Iter = 0;
        uint64_t nBL1Iter = 0;
        uint64_t dOutIter = 0;
        uint64_t kIter = 0;
        uint64_t kAL0Iter = 0;
        uint64_t kBL0Iter = 0;
        uint64_t mAL0Iter = 0;
        uint64_t nBL0Iter = 0;
        uint64_t groupOptIter = 0;

        // ---- Iteration bounds (presumably the last valid index or count - verify at the use site)
        uint64_t maxKAL1Iter = 0;
        uint64_t maxMAL1Iter = 0;
        uint64_t maxNBL1Iter = 0;
        uint64_t maxKBL1Iter = 0;
        uint64_t maxNL0Iter = 0;
        uint64_t maxML0Iter = 0;
        uint64_t maxKL0Iter = 0;
        uint64_t maxDOutIter = 0;
        uint64_t maxGroupOptIter = 0;

        // ---- Loop trip counts per transfer stage (names: ddr2l1 / l12l0 / ddr2l0) -------------
        uint64_t ddr2l1LoopN = 0;
        uint64_t l12l0LoopN = 0;
        uint64_t ddr2l1LoopD = 0;
        uint64_t ddr2l1LoopM = 0;
        uint64_t l12l0LoopM = 0;
        uint64_t ddr2l0LoopK = 0;

        // ---- Original (whole-problem) shape values -------------------------------------------
        uint64_t orgCi = 0;
        uint64_t orgCo = 0;
        uint64_t orgDi = 0;
        uint64_t orgDo = 0;
        uint64_t orgHi = 0;
        uint64_t orgWi = 0;
        uint64_t orgHo = 0;
        uint64_t orgWo = 0;

        // ---- Convolution attributes: kernel size, stride, dilation, padding ------------------
        uint64_t kernelD = 0;
        uint64_t kernelH = 0;
        uint64_t kernelW = 0;
        uint64_t strideD = 0;
        uint64_t strideH = 0;
        uint64_t strideW = 0;
        uint64_t dilationD = 0;
        uint64_t dilationH = 0;
        uint64_t dilationW = 0;
        uint64_t padHead = 0;
        uint64_t padTail = 0;
        uint64_t padUp = 0;
        uint64_t padDown = 0;
        uint64_t padLeft = 0;
        uint64_t padRight = 0;

        // ---- Single-core shape values ----------------------------------------------------------
        uint64_t singleCoreCin = 0;
        uint64_t singleCoreCo = 0;
        uint64_t singleCoreM = 0;
        uint64_t singleCoreDo = 0;

        // ---- Derived values (presumably precomputed from the fields above - verify in Init) ---
        uint64_t dilatedKernelH = 0;
        uint64_t dilatedKernelW = 0;
        uint64_t dilatedKernelD = 0;
        uint64_t cin0 = 0;
        uint64_t cin1 = 0;
        uint64_t cout1 = 0;
        uint64_t kernelHxkernelW = 0;
        uint64_t kernelHxkernelWxkernelD = 0;

        // ---- Tail sizes -------------------------------------------------------------------------
        uint64_t kL0Tail = 0;
        uint64_t mAL1Tail = 0;
        uint64_t mAL0Tail = 0;
        uint64_t nL0Tail = 0;
        uint64_t nBL1Tail = 0;

        // Multipliers default to 1, unlike the other counters.
        uint64_t multiKAL1 = 1;
        uint64_t multiKBL1 = 1;

        // ---- Start positions and aligned sizes --------------------------------------------------
        uint64_t mStartPos = 0;
        uint64_t diStartPos = 0;
        uint64_t orgCoAlignK0 = 0;
        uint64_t orgCoAlignN0 = 0;
        uint64_t nBL1TailAlign = 0;

        // ---- Element sizes (presumably sizeof of the corresponding Config types - TODO confirm)
        uint64_t sizeOfInput = 0;
        uint64_t sizeOfWeight = 0;
        uint64_t sizeOfBias = 0;
        uint64_t sizeOfL0c = 0;

        // ---- Group-convolution bookkeeping ------------------------------------------------------
        uint64_t groupKAL1 = 0;
        uint64_t groupKBL1 = 0;
        uint64_t groupKAL1Tail = 0;
        uint64_t groupKBL1Tail = 0;
        uint64_t singleCoreKL0 = 0;
        uint64_t preCorePerGroupSumCout = 0;
        uint64_t singleCoreGroupOpt = 0;
        uint64_t singleCoreCinTail = 0;
        uint64_t singleCoreCoutTail = 0;

        // ---- Preload flags ----------------------------------------------------------------------
        uint8_t preloadAL1DbFlag = false;
        uint8_t preloadABL1DbFlag = false;
    };

    /** Empty type: it declares no members beyond its default constructor. */
    struct ImplDataType {
        __aicore__ inline ImplDataType() {}
    };
};
}
#endif