/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
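/*!
* \file conv3d_common_func.h
* \brief Common Conv3d API implementations: GetTensorC (copy-out and buffer release) and
*        IterateAll (full iteration loop, including the grouped-convolution path).
*/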
#ifndef API_CONV3D_COMMON_FUNC_H
#define API_CONV3D_COMMON_FUNC_H
#include "conv3d_iterate_impl.h"
namespace Conv3dApiFunc {
// One CONV_DECLARE_REG_IMPL invocation per Conv3d API entry point.
// NOTE(review): the macro is defined outside this file; it presumably declares the
// per-entry-point implementation template (GetTensorC and IterateAll are defined below,
// the others elsewhere) - confirm against the macro definition before reordering or
// removing any line.
CONV_DECLARE_REG_IMPL(Init);
CONV_DECLARE_REG_IMPL(SetOrgInputShape);
CONV_DECLARE_REG_IMPL(SetOrgWeightShape);
CONV_DECLARE_REG_IMPL(SetOrgOutputShape);
CONV_DECLARE_REG_IMPL(SetSingleInputShape);
CONV_DECLARE_REG_IMPL(SetSingleOutputShape);
CONV_DECLARE_REG_IMPL(SetInputStartPosition);
CONV_DECLARE_REG_IMPL(SetGroupOptInfo);
CONV_DECLARE_REG_IMPL(Iterate);
CONV_DECLARE_REG_IMPL(GetTensorC);
CONV_DECLARE_REG_IMPL(IterateAll);
// Alias for a const-qualified anonymous struct that holds 1024 __uint128_t values.
// It is not referenced anywhere else in this file.
// NOTE(review): presumably a sentinel "no implementation" type consumed by the
// registration macros - confirm at the point of use before changing its layout.
using TypeFalse = const struct {
__uint128_t _[1024];
};
/**
 * Copy-out step that follows one Conv3d iteration.
 *
 * call() passes `output` to ctx.copyOutIns.CopyOut and then hands the tensors used by the
 * iteration back to their queues: always cl0, and - when bias is enabled - the bias tensors.
 */
template <class Intf, uint32_t ImplType>
struct GetTensorC {
    /**
     * @tparam sync              Not referenced by this implementation.
     * @param self               Conv3d API object; its ctx owns the queues and tensors used here.
     * @param output             Global tensor forwarded to ctx.copyOutIns.CopyOut.
     * @param enSequentialWrite  Not referenced by this implementation.
     * @return Always false.
     */
    template <bool sync = true>
    static __aicore__ inline bool call(
        Intf *self, const AscendC::GlobalTensor<typename Intf::OutputT> &output, bool enSequentialWrite = false)
    {
        auto &ctx = self->ctx;

        ctx.copyOutIns.CopyOut(output);
        ctx.queueCL0.FreeTensor(ctx.cl0);

        // With biasFullLoadFlag set, biasL1 stays allocated across iterations and is freed by
        // IterateAll::IterateAllBase after its loop, so it is only released here otherwise.
        if (ctx.enableBias && !ctx.biasFullLoadFlag) {
            ctx.queueBiasL1.FreeTensor(ctx.biasL1);
        }

        // biasBT is released for every format except NCDHW.
        if constexpr (Intf::formatType != ConvCommonApi::ConvFormat::NCDHW) {
            if (ctx.enableBias) {
                ctx.queueBiasBT.FreeTensor(ctx.biasBT);
            }
        }

        KERNEL_LOG(KERNEL_DEBUG, "[GetTensorC] GetTensorC Success! \n\n");
        return false;
    }
};
/**
 * Runs every Conv3d iteration assigned to the current core and copies each result to `output`.
 *
 * call() is the entry point; the other static members are helpers it relies on. When
 * Intf::groupConvType is set, the base loop is repeated once per group-opt block, and the last
 * block may be processed with reduced ("tail") channel counts.
 */
template <class Intf, uint32_t ImplType>
struct IterateAll {
    /**
     * Binds each pipeline stage object held in ctx to `self`, then dispatches to the grouped or
     * the plain iteration loop.
     *
     * @tparam sync          Not referenced by this implementation.
     * @param self           Conv3d API object whose ctx holds all state.
     * @param output         Global tensor that receives the results.
     * @param enPartialSum   Forwarded unchanged to Iterate<Intf, ImplType>::call.
     * @return Always false.
     */
    template <bool sync = true>
    static __aicore__ inline bool call(
        Intf *self, const AscendC::GlobalTensor<typename Intf::OutputT> &output, bool enPartialSum = false)
    {
        self->ctx.loadBiasL1Ins.SetParams(self);
        self->ctx.loadBL1Ins.SetParams(self);
        self->ctx.loadAl1Ins.SetParams(self);
        self->ctx.loadBL0Ins.SetParams(self);
        self->ctx.madIns.SetParams(self);
        self->ctx.copyOutIns.SetParams(self);
        self->ctx.loadBiasBTIns.SetParams(self);
        // For every format except NCDHW the AL0 loader is also given the AL1 loader.
        if constexpr (Intf::formatType == ConvCommonApi::ConvFormat::NCDHW) {
            self->ctx.loadAL0Ins.SetParams(self);
        } else {
            self->ctx.loadAL0Ins.SetParams(self, &self->ctx.loadAl1Ins);
        }

        if constexpr (Intf::groupConvType) {
            IterateAllWithGroups(self, output, enPartialSum);
        } else {
            IterateAllBase(self, output, enPartialSum);
        }
        return false;
    }

    /**
     * Runs Iterate/GetTensorC until Iterate reports completion, then resets the iteration state.
     *
     * When bias is enabled and biasFullLoadFlag is set, biasL1 is loaded once before the loop and
     * released once after it (GetTensorC does not free it in that mode).
     */
    static __aicore__ inline void IterateAllBase(
        Intf *self, const AscendC::GlobalTensor<typename Intf::OutputT> &output, bool enPartialSum = false)
    {
        if (self->ctx.biasFullLoadFlag && self->ctx.enableBias) {
            self->ctx.biasL1 = self->ctx.queueBiasL1.template AllocTensor<typename Intf::BiasT>();
            self->ctx.loadBiasL1Ins.LoadChannelWiseL1(self->ctx.biasL1, self->ctx.biasgm);
            self->ctx.queueBiasL1.EnQue(self->ctx.biasL1);
            self->ctx.biasL1 = self->ctx.queueBiasL1.template DeQue<typename Intf::BiasT>();
        }

        while (Iterate<Intf, ImplType>::call(self, enPartialSum)) {
            GetTensorC<Intf, ImplType>::call(self, output);
            if constexpr (Intf::formatType != ConvCommonApi::ConvFormat::NCDHW) {
                if (self->ctx.enableBias) {
                    self->ctx.queueBiasBT.FreeAllEvent();
                }
            }
        }

        if (self->ctx.biasFullLoadFlag && self->ctx.enableBias) {
            self->ctx.queueBiasL1.FreeTensor(self->ctx.biasL1);
        }

        // Leave ctx ready for the next full pass.
        self->ctx.isFirstIterate = true;
        self->ctx.nBL0Iter = 0;
        self->ctx.nBL1Iter = 0;
    }

    /**
     * Recomputes the K-direction sizes for the current ctx.singleCoreCin.
     *
     * Let curCinxKhxKw = AlignB(singleCoreCin, cin0) * kernelHxkernelW. For each of AL1 and BL1:
     *   - start from curCinxKhxKw if it does not exceed the tiling value (kAL1 / kBL1), else 0;
     *   - if GetCurrentKD(...) for that tiling value is greater than 1, use that count times
     *     curCinxKhxKw instead;
     *   - if the result is still 0, keep 0 when curCinxKhxKw is a multiple of the tiling value,
     *     otherwise use cin0 * kernelHxkernelW.
     * updateKL0 is set to cin0 only when either result is not a multiple of the tiling kL0;
     * otherwise it is left as passed in.
     * NOTE(review): a resulting 0 presumably tells InitKDirectionBaseValue to keep the tiling
     * value - confirm against that function.
     *
     * @param self        Conv3d API object.
     * @param updateKAL1  Out: recomputed AL1 K size.
     * @param updateKBL1  Out: recomputed BL1 K size.
     * @param updateKL0   In/out: overwritten with cin0 only under the condition described above.
     */
    static __aicore__ inline void ReCalculationKTilingWithGroups(
        Intf *self, uint64_t &updateKAL1, uint64_t &updateKBL1, uint64_t &updateKL0)
    {
        uint64_t curKAL1Kd = Conv3dApi::GetCurrentKD(
            self->ctx.conv3dTiling->kAL1, ConvApi::AlignB(self->ctx.orgCi, self->ctx.cin0), self->ctx.kernelHxkernelW);
        uint64_t curKBL1Kd = Conv3dApi::GetCurrentKD(
            self->ctx.conv3dTiling->kBL1, ConvApi::AlignB(self->ctx.orgCi, self->ctx.cin0), self->ctx.kernelHxkernelW);
        uint64_t curCinxKhxKw = ConvApi::AlignB(self->ctx.singleCoreCin, self->ctx.cin0) * self->ctx.kernelHxkernelW;

        updateKAL1 = curCinxKhxKw > self->ctx.conv3dTiling->kAL1 ? 0 : curCinxKhxKw;
        updateKBL1 = curCinxKhxKw > self->ctx.conv3dTiling->kBL1 ? 0 : curCinxKhxKw;

        if (curKAL1Kd > 1) {
            updateKAL1 = curKAL1Kd * curCinxKhxKw;
        }
        if (updateKAL1 == 0) {
            updateKAL1 =
                curCinxKhxKw % self->ctx.conv3dTiling->kAL1 == 0 ? 0 : self->ctx.cin0 * self->ctx.kernelHxkernelW;
        }

        if (curKBL1Kd > 1) {
            updateKBL1 = curKBL1Kd * curCinxKhxKw;
        }
        if (updateKBL1 == 0) {
            updateKBL1 =
                curCinxKhxKw % self->ctx.conv3dTiling->kBL1 == 0 ? 0 : self->ctx.cin0 * self->ctx.kernelHxkernelW;
        }

        if (updateKAL1 % self->ctx.conv3dTiling->kL0 != 0 || updateKBL1 % self->ctx.conv3dTiling->kL0 != 0) {
            updateKL0 = self->ctx.cin0;
        }
    }

    /**
     * Switches ctx to the tail channel counts before the last group-opt block is processed.
     *
     * Does nothing unless ctx.isGroupOptDimTail is set. A non-zero singleCoreCinTail replaces
     * singleCoreCin, re-derives the K-direction values and clears both preload flags; a non-zero
     * singleCoreCoutTail replaces singleCoreCo and re-derives the Cout-direction values.
     */
    static __aicore__ inline void PreProcessGroupOptDimTail(Intf *self)
    {
        if (!self->ctx.isGroupOptDimTail) {
            return;
        }

        if (self->ctx.singleCoreCinTail != 0) {
            // Values are cast to unsigned long long so the arguments match the %llu conversions.
            KERNEL_LOG(KERNEL_DEBUG, "[IterateAllWithGroups] singleCoreCin %llu update to %llu \n",
                static_cast<unsigned long long>(self->ctx.singleCoreCin),
                static_cast<unsigned long long>(self->ctx.singleCoreCinTail));
            self->ctx.singleCoreCin = self->ctx.singleCoreCinTail;

            uint64_t updateKAL1 = 0;
            uint64_t updateKBL1 = 0;
            uint64_t updateKL0 = 0;
            ReCalculationKTilingWithGroups(self, updateKAL1, updateKBL1, updateKL0);
            InitKDirectionBaseValue<Intf>(self, updateKAL1, updateKBL1, updateKL0);
            self->ctx.preloadAL1DbFlag = false;
            self->ctx.preloadABL1DbFlag = false;
            KERNEL_LOG(KERNEL_DEBUG, "[IterateAllWithGroups] updateKAL1 %llu updateKBL1 %llu updateKL0 %llu \n",
                static_cast<unsigned long long>(updateKAL1),
                static_cast<unsigned long long>(updateKBL1),
                static_cast<unsigned long long>(updateKL0));
        }

        if (self->ctx.singleCoreCoutTail != 0) {
            KERNEL_LOG(KERNEL_DEBUG, "[IterateAllWithGroups] singleCoreCo %llu update to %llu \n",
                static_cast<unsigned long long>(self->ctx.singleCoreCo),
                static_cast<unsigned long long>(self->ctx.singleCoreCoutTail));
            self->ctx.singleCoreCo = self->ctx.singleCoreCoutTail;
            InitCoutDirectionBaseValue<Intf>(self);
        }
    }

    /**
     * Undoes PreProcessGroupOptDimTail after the last group-opt block has been processed.
     *
     * Does nothing unless ctx.isGroupOptDimTail is set. singleCoreCin is restored to the tiling's
     * cinOpt (not to a saved value); singleCoreCo and the preload flags are restored from the
     * values the caller captured beforehand. isGroupOptDimTail is cleared at the end.
     *
     * @param self                  Conv3d API object.
     * @param tmpSingleCoreCo       singleCoreCo as it was before the tail pre-processing.
     * @param tmpPreloadAL1DbFlag   preloadAL1DbFlag as it was before the tail pre-processing.
     * @param tmpPreloadABL1DbFlag  preloadABL1DbFlag as it was before the tail pre-processing.
     */
    static __aicore__ inline void PostProcessGroupOptDimTail(Intf *self, const uint64_t &tmpSingleCoreCo,
        const uint8_t &tmpPreloadAL1DbFlag, const uint8_t &tmpPreloadABL1DbFlag)
    {
        if (!self->ctx.isGroupOptDimTail) {
            return;
        }

        if (self->ctx.singleCoreCin != self->ctx.conv3dTiling->cinOpt) {
            self->ctx.singleCoreCin = self->ctx.conv3dTiling->cinOpt;
            InitKDirectionBaseValue<Intf>(self);
            self->ctx.preloadAL1DbFlag = tmpPreloadAL1DbFlag;
            self->ctx.preloadABL1DbFlag = tmpPreloadABL1DbFlag;
        }

        if (self->ctx.singleCoreCo != tmpSingleCoreCo) {
            self->ctx.singleCoreCo = tmpSingleCoreCo;
            InitCoutDirectionBaseValue<Intf>(self);
        }

        self->ctx.isGroupOptDimTail = false;
    }

    /**
     * Grouped-convolution driver: runs IterateAllBase once per group-opt block.
     *
     * After every block except the last, the weight tensor is advanced by one block
     * (cinOpt * kernelHxkernelWxkernelD * coutOpt elements) and, when bias is enabled, the bias
     * tensor by coutOpt elements. The last block is wrapped in the tail pre/post-processing.
     * groupOptIter is reset to 0 on exit.
     */
    static __aicore__ inline void IterateAllWithGroups(
        Intf *self, const AscendC::GlobalTensor<typename Intf::OutputT> &output, bool enPartialSum = false)
    {
        // Widen before multiplying so the product cannot overflow a narrower operand type.
        const uint64_t weightOneGroupOptSize = static_cast<uint64_t>(self->ctx.conv3dTiling->cinOpt) *
            self->ctx.kernelHxkernelWxkernelD * self->ctx.conv3dTiling->coutOpt;

        // "iter + 1 < max" instead of "iter < max - 1": if maxGroupOptIter is unsigned and 0,
        // the subtraction would wrap around and the loop would run far past the data.
        while (self->ctx.groupOptIter + 1 < self->ctx.maxGroupOptIter) {
            IterateAllBase(self, output, enPartialSum);
            self->SetWeight(self->ctx.bgm[weightOneGroupOptSize]);
            if (self->ctx.enableBias) {
                self->SetBias(self->ctx.biasgm[self->ctx.conv3dTiling->coutOpt]);
            }
            self->ctx.groupOptIter++;
        }

        // Capture the state that the tail pre-processing may overwrite.
        uint64_t tmpSingleCoreCo = self->ctx.singleCoreCo;
        uint8_t tmpPreloadAL1DbFlag = self->ctx.preloadAL1DbFlag;
        uint8_t tmpPreloadABL1DbFlag = self->ctx.preloadABL1DbFlag;

        PreProcessGroupOptDimTail(self);
        IterateAllBase(self, output, enPartialSum);
        PostProcessGroupOptDimTail(self, tmpSingleCoreCo, tmpPreloadAL1DbFlag, tmpPreloadABL1DbFlag);

        self->ctx.groupOptIter = 0;
    }
};
}
#endif