/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
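/*!
 * \file conv3d_common_set_func.h
 * \brief Setter functors for the Conv3d API: each records original (org*) shapes, per-core (single*) shapes,
 *        input start positions or group-optimization tail info into the API object's context (self->ctx).
 */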
#ifndef API_CONV3D_COMMON_SET_FUNC_H
#define API_CONV3D_COMMON_SET_FUNC_H
#include "../common/conv_forward_framework_util.h"
#include "conv3d_common_sub_api.h"
#include "include/adv_api/conv/conv3d/conv3d_config.h"
#include "kernel_basic_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "kernel_utils.h"
namespace Conv3dApiFunc {
/**
 * \brief Records the original (org*) input feature-map shape in the API context.
 *
 * \param self   API object whose ctx receives the values.
 * \param orgCi  Original input channel count.
 * \param orgDi  Original input depth.
 * \param orgHi  Original input height.
 * \param orgWi  Original input width.
 */
template <class Intf, uint32_t ImplType>
struct SetOrgInputShape {
    static __aicore__ inline void call(Intf *self, uint64_t orgCi, uint64_t orgDi, uint64_t orgHi, uint64_t orgWi)
    {
        // Store into ctx.orgCi: the same field SetOrgWeightShape writes, and the name that matches the
        // orgDi/orgHi/orgWi siblings. (A misspelled member here is not caught until template instantiation.)
        // NOTE(review): SetOrgWeightShape also writes ctx.orgCi, so whichever setter runs last wins;
        // confirm callers pass consistent channel counts to both.
        self->ctx.orgCi = orgCi;
        self->ctx.orgDi = orgDi;
        self->ctx.orgHi = orgHi;
        self->ctx.orgWi = orgWi;
    }
};
/**
 * \brief Records the original weight shape (output channels, input channels, kernel extents) in the API context.
 *
 * \param self   API object whose ctx receives the values.
 * \param orgCo  Original output channel count (stored as ctx.orgCo).
 * \param orgCi  Original input channel count of the weight (stored as ctx.orgCi).
 * \param orgKd  Kernel depth (stored as ctx.kernelD).
 * \param orgKh  Kernel height (stored as ctx.kernelH).
 * \param orgKw  Kernel width (stored as ctx.kernelW).
 */
template <class Intf, uint32_t ImplType>
struct SetOrgWeightShape {
    static __aicore__ inline void call(
        Intf *self, uint64_t orgCo, uint64_t orgCi, uint64_t orgKd, uint64_t orgKh, uint64_t orgKw)
    {
        auto &ctx = self->ctx;
        // Kernel extents live under kernel* names rather than org* names.
        ctx.kernelD = orgKd;
        ctx.kernelH = orgKh;
        ctx.kernelW = orgKw;
        // Channel counts.
        ctx.orgCo = orgCo;
        ctx.orgCi = orgCi;
    }
};
/**
 * \brief Records the original (org*) output shape in the API context.
 *
 * \param self   API object whose ctx receives the values.
 * \param orgCo  Original output channel count (same ctx.orgCo field SetOrgWeightShape writes).
 * \param orgDo  Original output depth.
 * \param orgHo  Original output height.
 * \param orgWo  Original output width.
 */
template <class Intf, uint32_t ImplType>
struct SetOrgOutputShape {
    static __aicore__ inline void call(Intf *self, uint64_t orgCo, uint64_t orgDo, uint64_t orgHo, uint64_t orgWo)
    {
        auto &ctx = self->ctx;
        // Spatial extents first, then the channel count; the stores are independent.
        ctx.orgDo = orgDo;
        ctx.orgHo = orgHo;
        ctx.orgWo = orgWo;
        ctx.orgCo = orgCo;
    }
};
/**
 * \brief Records the per-core (single*) input channel extent and refreshes the K-direction base values.
 *
 * \param self      API object whose ctx receives the values.
 * \param singleCi  Per-core input channel count (stored as ctx.singleCoreCin).
 * \param singleDi  Not read by this implementation.
 * \param singleHi  Not read by this implementation.
 * \param singleWi  Not read by this implementation.
 *
 * NOTE(review): only the channel extent is stored; confirm that the per-core input D/H/W window is
 * derived elsewhere (e.g. from the output shape) and that ignoring these arguments is intended.
 */
template <class Intf, uint32_t ImplType>
struct SetSingleInputShape {
    static __aicore__ inline void call(
        Intf *self, uint64_t singleCi, uint64_t singleDi, uint64_t singleHi, uint64_t singleWi)
    {
        auto &ctx = self->ctx;
        ctx.singleCoreCin = singleCi;
        // Must run after the store above; presumably derives K-direction values from singleCoreCin.
        InitKDirectionBaseValue<Intf>(self);
    }
};
/**
 * \brief Records the per-core (single*) output shape and group-optimization extent, then refreshes the
 *        derived per-direction values.
 *
 * Two overloads are provided:
 *  - separate Ho/Wo extents: stores singleCoreHo;
 *  - a flattened M extent: stores singleCoreM and also refreshes the M-direction values.
 * In both, every ctx store happens before any Init* helper runs, and the helpers run in a fixed order.
 */
template <class Intf, uint32_t ImplType>
struct SetSingleOutputShape {
    /**
     * \param self            API object whose ctx receives the values.
     * \param singleCo        Per-core output channel count.
     * \param singleDo        Per-core output depth.
     * \param singleHo        Per-core output height.
     * \param singleWo        Not read by this implementation. NOTE(review): confirm Wo is not needed per core.
     * \param singleGroupOpt  Per-core group-optimization extent.
     */
    static __aicore__ inline void call(
        Intf *self, uint64_t singleCo, uint64_t singleDo, uint64_t singleHo,
        uint64_t singleWo, uint64_t singleGroupOpt)
    {
        auto &ctx = self->ctx;
        ctx.singleCoreGroupOpt = singleGroupOpt;
        ctx.singleCoreHo = singleHo;
        ctx.singleCoreDo = singleDo;
        ctx.singleCoreCo = singleCo;

        // Derived values: keep this call order.
        InitCoutDirectionBaseValue<Intf>(self);
        InitDoutDirectionBaseValue<Intf>(self);
        InitGroupOptDirectionValue<Intf>(self);
    }

    /**
     * \param self            API object whose ctx receives the values.
     * \param singleCo        Per-core output channel count.
     * \param singleDo        Per-core output depth.
     * \param singleCoreM     Per-core M extent.
     * \param singleGroupOpt  Per-core group-optimization extent.
     */
    static __aicore__ inline void call(
        Intf *self, uint64_t singleCo, uint64_t singleDo, uint64_t singleCoreM, uint64_t singleGroupOpt)
    {
        auto &ctx = self->ctx;
        ctx.singleCoreGroupOpt = singleGroupOpt;
        ctx.singleCoreM = singleCoreM;
        ctx.singleCoreDo = singleDo;
        ctx.singleCoreCo = singleCo;

        // Derived values: the M direction is refreshed first, then the same sequence as the Ho/Wo overload.
        InitMDirectionBaseValue<Intf>(self);
        InitCoutDirectionBaseValue<Intf>(self);
        InitDoutDirectionBaseValue<Intf>(self);
        InitGroupOptDirectionValue<Intf>(self);
    }
};
/**
 * \brief Records where this core's input window starts.
 *
 * NOTE(review): neither overload stores the channel start (ciStartPos), and the H/W overload stores only
 * the depth start; confirm the ignored positions are handled elsewhere.
 */
template <class Intf, uint32_t ImplType>
struct SetInputStartPosition {
    /**
     * Overload taking separate H/W start positions; only diStartPos is stored.
     * \param self        API object whose ctx receives the values.
     * \param diStartPos  Input depth start position (stored as ctx.diStartPos).
     * \param hiStartPos  Not read by this implementation.
     * \param wiStartPos  Not read by this implementation.
     * \param ciStartPos  Not read by this implementation.
     */
    static __aicore__ inline void call(
        Intf *self, int64_t diStartPos, int64_t hiStartPos, int64_t wiStartPos, int64_t ciStartPos)
    {
        auto &ctx = self->ctx;
        ctx.diStartPos = diStartPos;
    }

    /**
     * Overload taking a flattened M start position.
     * \param self        API object whose ctx receives the values.
     * \param diStartPos  Input depth start position (stored as ctx.diStartPos).
     * \param mStartPos   M start position (stored as ctx.mStartPos).
     * \param ciStartPos  Not read by this implementation.
     */
    static __aicore__ inline void call(Intf *self, int64_t diStartPos, int64_t mStartPos, int64_t ciStartPos)
    {
        auto &ctx = self->ctx;
        ctx.mStartPos = mStartPos;
        ctx.diStartPos = diStartPos;
    }
};
/**
 * \brief Records the tail extents used by the group-optimization path.
 *
 * \param self                API object whose ctx receives the values.
 * \param singleCoreCinTail   Tail input-channel extent (stored as ctx.singleCoreCinTail).
 * \param singleCoreCoutTail  Tail output-channel extent (stored as ctx.singleCoreCoutTail).
 * \param isGroupOptDimTail   Flag stored as ctx.isGroupOptDimTail; defaults to false.
 */
template <class Intf, uint32_t ImplType>
struct SetGroupOptInfo {
    static __aicore__ inline void call(
        Intf *self, uint64_t singleCoreCinTail, uint64_t singleCoreCoutTail, bool isGroupOptDimTail = false)
    {
        auto &ctx = self->ctx;
        ctx.isGroupOptDimTail = isGroupOptDimTail;
        ctx.singleCoreCoutTail = singleCoreCoutTail;
        ctx.singleCoreCinTail = singleCoreCinTail;
    }
};
}
#endif