/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
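/*!
 * \file conv3d_bp_util.h
 * \brief Integer division/alignment helpers and compile-time detection macros for Conv3D backprop kernels.
 */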
#ifndef CONV3D_BP_UTIL_H
#define CONV3D_BP_UTIL_H
#include "kernel_utils.h"
namespace ConvBackpropApi {
/**
 * Ceiling division of a by 16, done with an add and a shift instead of a divide.
 *
 * @param a value to divide
 * @return (a + 15) >> 4. The addition is not guarded, so it wraps for a within 15 of UINT64_MAX.
 */
static __aicore__ inline uint64_t ShiftCeilBlockCube(uint64_t a)
{
    constexpr uint64_t shiftBits = 4;
    constexpr uint64_t roundUp = (1ULL << shiftBits) - 1;  // 15
    return (a + roundUp) >> shiftBits;
}
/**
 * Integer quotient a / b with comparison shortcuts that skip the division in trivial cases.
 *
 * @param a dividend
 * @param b divisor
 * @return 0 when b > a, 1 when b == a, a when b == 1, otherwise a / b.
 *         Note that a == 0 with b == 0 takes the b == a shortcut and yields 1; b == 0 with a > 0
 *         reaches the real division.
 */
static __aicore__ inline uint64_t CalcDiv(uint64_t a, uint32_t b)
{
    if (b > a) {
        return 0;
    }
    if (b == a) {
        return 1;
    }
    if (b == 1) {
        return a;
    }
    return a / b;
}
/**
 * Rounds a down to a multiple of b, with shortcuts that avoid the division.
 *
 * @param a value to align
 * @param b alignment unit
 * @return b itself when b >= a (the result is b, not 0, in that case), a when b == 1,
 *         otherwise the largest multiple of b that does not exceed a.
 */
static __aicore__ inline uint64_t CalcFloorAlign(uint64_t a, uint32_t b)
{
    if (b >= a) {
        return b;
    }
    if (b == 1) {
        return a;
    }
    const uint64_t wholeUnits = a / b;
    return wholeUnits * b;
}
/**
 * Remainder of a divided by b, with shortcuts that skip the modulo in trivial cases.
 *
 * @param a dividend
 * @param b divisor
 * @return a when b > a, 0 when b is 1 or equal to a, otherwise a % b.
 */
static __aicore__ inline uint64_t CalcRemainder(uint64_t a, uint32_t b)
{
    if (b > a) {
        return a;
    }
    const bool dividesExactly = (b == 1) || (b == a);
    if (dividesExactly) {
        return 0;
    }
    return a % b;
}
// Step value for which the *StepM / *StepN helpers use a shift or mask instead of a division or modulo.
constexpr uint32_t STEP_2 = 2U;
/**
 * Ceiling division of a by b.
 *
 * @param a dividend
 * @param b divisor; a zero divisor is only diagnosed (via ASCENDC_ASSERT) when ASCENDC_CPU_DEBUG is defined
 * @return (a + b - 1) / b, evaluated in 64-bit unsigned arithmetic without an overflow guard.
 */
static __aicore__ inline uint64_t Ceil(uint64_t a, uint32_t b)
{
#ifdef ASCENDC_CPU_DEBUG
    ASCENDC_ASSERT(b != 0, { KERNEL_LOG(KERNEL_ERROR, "Division by zero error"); });
#endif
    const uint64_t divisor = b;
    const uint64_t biased = a + divisor - 1;
    return biased / divisor;
}
/**
 * Floor division of a by stepM, with division-free paths for steps 1 and STEP_2.
 *
 * @param a     dividend
 * @param stepM divisor
 * @return a for stepM == 1, a >> 1 for stepM == STEP_2, otherwise a / stepM.
 */
static __aicore__ inline uint64_t DivStepM(uint64_t a, uint32_t stepM)
{
    switch (stepM) {
        case 1:
            return a;
        case STEP_2:
            return a >> 1;
        default:
            return a / stepM;
    }
}
/**
 * Ceiling division of a by stepM, with division-free paths for steps 1 and STEP_2.
 *
 * @param a     dividend
 * @param stepM divisor
 * @return a for stepM == 1, (a + 1) >> 1 for stepM == STEP_2, otherwise (a + stepM - 1) / stepM.
 */
static __aicore__ inline uint64_t CeilStepM(uint64_t a, uint32_t stepM)
{
    switch (stepM) {
        case 1:
            return a;
        case STEP_2:
            return (a + 1) >> 1;
        default:
            return (a + stepM - 1) / stepM;
    }
}
/**
 * Remainder of a divided by stepM, with modulo-free paths for steps 1 and STEP_2.
 *
 * @param a     dividend
 * @param stepM divisor
 * @return 0 for stepM == 1, a & 1 for stepM == STEP_2, otherwise a % stepM.
 */
static __aicore__ inline uint64_t RemainderStepM(uint64_t a, uint32_t stepM)
{
    switch (stepM) {
        case 1:
            return 0;
        case STEP_2:
            return a & 1;
        default:
            return a % stepM;
    }
}
/**
 * Floor division of a by stepN, with division-free paths for steps 1 and STEP_2.
 *
 * @param a     dividend
 * @param stepN divisor
 * @return a for stepN == 1, a >> 1 for stepN == STEP_2, otherwise a / stepN.
 */
static __aicore__ inline uint64_t DivStepN(uint64_t a, uint32_t stepN)
{
    switch (stepN) {
        case 1:
            return a;
        case STEP_2:
            return a >> 1;
        default:
            return a / stepN;
    }
}
/**
 * Ceiling division of a by stepN, with division-free paths for steps 1 and STEP_2.
 *
 * @param a     dividend
 * @param stepN divisor
 * @return a for stepN == 1, (a + 1) >> 1 for stepN == STEP_2, otherwise (a + stepN - 1) / stepN.
 */
static __aicore__ inline uint64_t CeilStepN(uint64_t a, uint32_t stepN)
{
    switch (stepN) {
        case 1:
            return a;
        case STEP_2:
            return (a + 1) >> 1;
        default:
            return (a + stepN - 1) / stepN;
    }
}
/**
 * Remainder of a divided by stepN, with modulo-free paths for steps 1 and STEP_2.
 *
 * @param a     dividend
 * @param stepN divisor
 * @return 0 for stepN == 1, a & 1 for stepN == STEP_2, otherwise a % stepN.
 */
static __aicore__ inline uint64_t RemainderStepN(uint64_t a, uint32_t stepN)
{
    switch (stepN) {
        case 1:
            return 0;
        case STEP_2:
            return a & 1;
        default:
            return a % stepN;
    }
}
/**
 * Floor division of a by hkWk, skipping the divide when the answer is obvious.
 *
 * @param a    dividend
 * @param hkWk divisor
 * @return 0 when hkWk > a, a when hkWk == 1, otherwise a / hkWk.
 */
static __aicore__ inline uint64_t DivHkWk(uint64_t a, uint32_t hkWk)
{
    if (hkWk > a) {
        return 0;
    }
    return (hkWk == 1) ? a : (a / hkWk);
}
/**
 * Ceiling division of a by hkWk, skipping the divide when the answer is obvious.
 *
 * @param a    dividend
 * @param hkWk divisor
 * @return 1 whenever hkWk > a (this includes a == 0), a when hkWk == 1, otherwise Ceil(a, hkWk).
 */
static __aicore__ inline uint64_t CeilHkWk(uint64_t a, uint32_t hkWk)
{
    if (hkWk > a) {
        return 1;
    }
    if (hkWk == 1) {
        return a;
    }
    return Ceil(a, hkWk);
}
/**
 * Remainder of a divided by hkWk, skipping the modulo when the answer is obvious.
 *
 * @param a    dividend
 * @param hkWk divisor
 * @return a when hkWk > a, 0 when hkWk == 1, otherwise a % hkWk.
 */
static __aicore__ inline uint64_t RemainderOfHkWk(uint64_t a, uint32_t hkWk)
{
    if (hkWk > a) {
        return a;
    }
    return (hkWk == 1) ? 0 : (a % hkWk);
}
/**
 * Floor division of a by a power of two chosen at compile time from Intf::SrcT.
 *
 * The shift is 3 (divide by 8) when Intf::SrcT is float and 4 (divide by 16) for every other type,
 * half and bfloat16_t included.
 *
 * @tparam Intf       type exposing a nested SrcT
 * @param a           value to divide
 * @param channelSize not read; the divisor is derived from Intf::SrcT alone
 * @return a >> 3 for float, a >> 4 otherwise.
 */
template <class Intf>
static __aicore__ inline uint64_t ShiftDivChannelSize(uint64_t a, uint32_t channelSize)
{
    constexpr bool isFloatSrc = AscendC::IsSameType<typename Intf::SrcT, float>::value;
    constexpr uint32_t shiftBits = isFloatSrc ? 3 : 4;
    return a >> shiftBits;
}
/**
 * Ceiling division of a by a power of two chosen at compile time from Intf::SrcT.
 *
 * The divisor is 8 when Intf::SrcT is float and 16 for every other type, half and bfloat16_t included.
 *
 * @tparam Intf       type exposing a nested SrcT
 * @param a           value to divide
 * @param channelSize not read; the divisor is derived from Intf::SrcT alone
 * @return (a + 7) >> 3 for float, (a + 15) >> 4 otherwise.
 */
template <class Intf>
static __aicore__ inline uint64_t ShiftCeilChannelSize(uint64_t a, uint32_t channelSize)
{
    constexpr bool isFloatSrc = AscendC::IsSameType<typename Intf::SrcT, float>::value;
    constexpr uint32_t shiftBits = isFloatSrc ? 3 : 4;
    constexpr uint64_t roundUp = (1ULL << shiftBits) - 1;  // 7 or 15
    return (a + roundUp) >> shiftBits;
}
/**
 * Floor division of a by 16, done with a shift.
 *
 * @param a  value to divide
 * @param m0 not read; the divisor is fixed at 16
 * @return a >> 4.
 */
static __aicore__ inline uint64_t ShiftDivM0(uint64_t a, uint32_t m0)
{
    constexpr uint32_t shiftBits = 4;
    return a >> shiftBits;
}
/**
 * Ceiling division of a by 16, done with an add and a shift.
 *
 * @param a  value to divide
 * @param m0 not read; the divisor is fixed at 16
 * @return (a + 15) >> 4.
 */
static __aicore__ inline uint64_t ShiftCeilM0(uint64_t a, uint32_t m0)
{
    constexpr uint32_t shiftBits = 4;
    constexpr uint64_t roundUp = (1ULL << shiftBits) - 1;  // 15
    return (a + roundUp) >> shiftBits;
}
/**
 * Derives a row count from copySize and width.
 *
 * - copySize == width: 1.
 * - copySize <  width: 1 when width is a multiple of copySize, otherwise 2.
 * - copySize >  width: rows = AscendC::Ceil(copySize, width); the result is rows when copySize is an exact
 *   multiple of width or when 2 * copySize is a multiple of width, otherwise rows + 1.
 *
 * NOTE(review): this looks like the number of width-sized rows a copy of copySize elements may touch, with one
 * extra row for misaligned cases - confirm against the callers.
 *
 * @param copySize size being copied; must be non-zero when smaller than width (it is used as a modulus)
 * @param width    row width
 * @return the row count described above.
 */
static __aicore__ inline uint32_t CalRows2Copy(uint32_t copySize, uint32_t width)
{
    if (copySize == width) {
        return 1;
    }
    if (copySize < width) {
        return (width % copySize != 0) ? 2 : 1;
    }
    // copySize > width from here on.
    const uint32_t fullRows = AscendC::Ceil(copySize, width);
    if (copySize == fullRows * width || (2 * copySize) % width == 0) {
        return fullRows;
    }
    return fullRows + 1;
}
/*
 * Emits a catch-all static function `call(T *self, Ts... args)` that accepts any trailing arguments and returns
 * a NAMESPACE::TypeFalse built from {0}. CHECK_FUN compares the return type of T::call(...) against
 * NAMESPACE::TypeFalse, so selecting this fallback is what makes CHECK_FUN evaluate to false.
 * Presumably expanded inside the body of a class that may also provide more specific `call` overloads - confirm
 * at the use sites. `(Type){0}` is compound-literal syntax, which is a compiler extension in C++.
 */
#define DECLARE_DEFAULT_OVERLOADING_FUN(T, NAMESPACE) \
template <class... Ts> \
static __aicore__ inline NAMESPACE::TypeFalse call(T *self, Ts... args) \
{ \
return ((NAMESPACE::TypeFalse){0}); \
}
/*
 * Compile-time boolean: true when the `T::call` overload selected for the given arguments returns a type other
 * than NAMESPACE::TypeFalse. The call sits inside decltype, so it is never executed.
 */
#define CHECK_FUN(T, NAMESPACE, ...) (!AscendC::IsSameType<decltype(T::call(__VA_ARGS__)), NAMESPACE::TypeFalse>::value)
/*
 * Define a verification template class that determines whether type T has the member template MEMBER<U>.
 * Used together with macro DECLARE_IMPL; query the result with _has_impl_MEMBER<T, U>::value.
 * In check(int), decltype yields the type of its expression and std::declval is a function template that returns
 * an rvalue reference to its template argument. If T does not have MEMBER<U>, that overload fails to substitute
 * and check(...) returning FalseType is selected; otherwise TrueType is returned.
 */
/*
 * Declares, in namespace __<NAMESPACE>, the trait _has_impl_<MEMBER><T, U>. Its `value` is true when
 * `typename T::template MEMBER<U, args...>` names a valid type: check(int) only survives substitution in that
 * case and is a better match for the argument 0 than check(...), which returns AscendC::FalseType.
 * `args...` and `##args` are the GNU named-variadic-macro and comma-swallowing extensions; with no extra
 * arguments the template argument list is just <U>.
 */
#define DECLARE_CHECK_IMPL(NAMESPACE, MEMBER, args...) \
namespace __##NAMESPACE { \
template <typename T, typename U> \
struct _has_impl_##MEMBER { \
template <typename C> \
static auto check(int) -> decltype(std::declval<typename C::template MEMBER<U, ##args>>(), AscendC::TrueType()); \
template <typename C> \
static AscendC::FalseType check(...); \
static constexpr bool value = AscendC::IsSameType<decltype(check<T>(0)), AscendC::TrueType>::value; \
}; \
}
/*
 * Variant of the _has_impl_<MEMBER> trait with an extra `bool sync` template parameter. Declares, in namespace
 * __<NAMESPACE>, _has_impl_<MEMBER><T, U, sync>, whose `value` is true when
 * `typename T::template MEMBER<U, sync, args...>` names a valid type. check(int) is selected for the argument 0
 * when that type is valid; otherwise check(...) returning AscendC::FalseType is selected.
 * `args...` and `##args` are the GNU named-variadic-macro and comma-swallowing extensions.
 */
#define DECLARE_CHECK_SYNC_IMPL(NAMESPACE, MEMBER, args...) \
namespace __##NAMESPACE { \
template <typename T, typename U, bool sync> \
struct _has_impl_##MEMBER { \
template <typename C> \
static auto check(int) -> decltype(std::declval<typename C::template MEMBER<U, sync, ##args>>(), AscendC::TrueType()); \
template <typename C> \
static AscendC::FalseType check(...); \
static constexpr bool value = AscendC::IsSameType<decltype(check<T>(0)), AscendC::TrueType>::value; \
}; \
}
/*
 * Declares the alias `MEMBER` in the scope where the macro is expanded, via a helper selector struct
 * __<MEMBER><bool, T>:
 *   - when __<NAMESPACE>::_has_impl_<MEMBER><Config, U>::value is true, MEMBER is Config::MEMBER<U>;
 *   - otherwise MEMBER is Current::MEMBER<U>.
 * The trait __<NAMESPACE>::_has_impl_<MEMBER> must already be declared (two-parameter form, as produced by
 * DECLARE_CHECK_IMPL). The expansion ends without a semicolon, so the caller supplies it.
 * Note: the selector's first parameter is named `Default`, yet it is the `true` specialization that picks
 * Config's implementation and the primary template that falls back to Current.
 */
#define DECLARE_IMPL(NAMESPACE, Config, Current, MEMBER, U) \
template <bool Default, class T> \
struct __##MEMBER { \
using Type = typename Current::MEMBER<U>; \
}; \
template <class T> \
struct __##MEMBER<true, T> { \
using Type = typename T::template MEMBER<U>; \
}; \
using MEMBER = typename __##MEMBER<__##NAMESPACE::_has_impl_##MEMBER<Config, U>::value, Config>::Type
/*
 * Declares the alias template `template <bool sync> MEMBER` in the scope where the macro is expanded, via a
 * helper selector struct __<MEMBER><bool, T, sync>:
 *   - when __<NAMESPACE>::_has_impl_<MEMBER><Config, U, sync>::value is true, MEMBER<sync> is
 *     Config::MEMBER<U, sync>;
 *   - otherwise MEMBER<sync> is Current::MEMBER<U, sync>.
 * The trait __<NAMESPACE>::_has_impl_<MEMBER> must already be declared (three-parameter form, as produced by
 * DECLARE_CHECK_SYNC_IMPL). The expansion ends without a semicolon, so the caller supplies it.
 */
#define DECLARE_SYNC_IMPL(NAMESPACE, Config, Current, MEMBER, U) \
template <bool Default, class T, bool sync> \
struct __##MEMBER { \
using Type = typename Current::MEMBER<U, sync>; \
}; \
template <class T, bool sync> \
struct __##MEMBER<true, T, sync> { \
using Type = typename T::template MEMBER<U, sync>; \
}; \
template <bool sync> \
using MEMBER = typename __##MEMBER<__##NAMESPACE::_has_impl_##MEMBER<Config, U, sync>::value, Config, sync>::Type
/*
 * Define a template class that determines whether type T has a member named MEMBER.
 * Used in conjunction with macro CHECK_MEMBER; query the result with _has_member_MEMBER<T>::value.
 */
/*
 * Declares, in namespace __AuxCheck, the trait _has_member_<MEMBER><T>. Its `value` is true when `&T::MEMBER`
 * is a well-formed expression, i.e. T has an unambiguous member named MEMBER.
 *
 * How it works: check<T>(nullptr) prefers the overload taking decltype(&U::MEMBER), which returns void. When
 * that parameter type cannot be formed, the overload is discarded and the variadic check(...) returning int
 * is selected. `value` tests whether the selected overload returns void.
 *
 * The parameter is written as `decltype(&U::MEMBER)` with no extra parentheses: a parenthesized type is not a
 * valid parameter declaration. IsSameType is qualified with AscendC:: like every other use in this file, so the
 * expansion does not depend on a using-directive at the expansion site.
 */
#define DECLARE_CHECK_MEMBER(MEMBER) \
    namespace __AuxCheck { \
    template <typename T> \
    struct _has_member_##MEMBER { \
        template <typename U> \
        static void check(decltype(&U::MEMBER)); \
        template <typename U> \
        static int check(...); \
        static constexpr bool value = AscendC::IsSameType<decltype(check<T>(nullptr)), void>::value; \
    }; \
    }
/*
 * Reads __AuxCheck::_has_member_<MEMBER><OBJ>::value, the trait declared by DECLARE_CHECK_MEMBER(MEMBER).
 * OBJ is prefixed with `typename`, so it has to be spelled as a qualified type name (e.g. `Outer::Inner`).
 */
#define CHECK_MEMBER(OBJ, MEMBER) (__AuxCheck::_has_member_##MEMBER<typename OBJ>::value)
/*
 * Define two auxiliary template classes: in one, member M is a variable; in the other, member M is a constant.
 * A verification function template is also defined. Its return type is one of those two classes, chosen by
 * whether the template parameter T has a constant member M with a value > 0.
 * Used with macro DEFINE_STUCT.
 */
/*
 * Declares, in namespace __AuxTiling, the pieces that let a field M of type U be either a runtime member or a
 * compile-time constant:
 *   - <T>_<M><TT>:    holds a non-static member `U M`; __CONST_TYPE_<M> is false.
 *   - <T>_CT_<M><TT>: holds `constexpr static U M = TT::M`; __CONST_TYPE_<M> is true.
 *   - _is_const_<T>_<M><TT>(): constexpr predicate returning `TT::M > 0`.
 *   - <T>_<M>_checkdefine<TT>(): declaration only (never defined); its return type is <T>_CT_<M><TT> when the
 *     predicate is true and <T>_<M><TT> otherwise, so it is meant to be used inside decltype.
 * TT::M must be usable in a constant expression because the predicate is evaluated as a template argument.
 * The `;` after the predicate's closing brace is an empty declaration.
 */
#define DECLARE_DEFINE_STRUCT(T, M, U) \
namespace __AuxTiling { \
template <typename TT> \
struct T##_##M { \
U M; \
constexpr static bool __CONST_TYPE_##M = false; \
}; \
template <typename TT> \
struct T##_CT_##M { \
constexpr static U M = (TT::M); \
constexpr static bool __CONST_TYPE_##M = true; \
}; \
template <typename TT> \
constexpr bool _is_const_##T##_##M() \
{ \
return (TT::M) > 0; \
}; \
template <typename TT> \
typename std::conditional<_is_const_##T##_##M<TT>(), T##_CT_##M<TT>, T##_##M<TT>>::type T##_##M##_checkdefine(); \
}
/*
 * Expands to `public decltype(__AuxTiling::<T>_<M>_checkdefine<T>())`, i.e. a public base-specifier naming the
 * struct that DECLARE_DEFINE_STRUCT selects for field M of T. The checkdefine function is only named inside
 * decltype and is never called.
 */
#define DEFINE_STUCT(T, M) \
public \
decltype(__AuxTiling::T##_##M##_checkdefine<T>())
/*
 * Declares a member `T FIELD` together with the flag `__CONST_TYPE_<FIELD> = false`, which CHECK_CONST reads.
 * The expansion ends without a semicolon, so the caller supplies it.
 */
#define DEFINE_STUCT_FIELD(T, FIELD) \
T FIELD; \
constexpr static bool __CONST_TYPE_##FIELD = false
/* Reads the static flag T::__CONST_TYPE_<M>: true when M was declared as a compile-time constant of T. */
#define CHECK_CONST(T, M) ((T::__CONST_TYPE_##M))
/*
 * Declares a member `T<C, ...> FIELD` (T is a template, C and any further arguments are its template
 * arguments) together with the flag `__CONST_TYPE_<FIELD> = false`.
 * `##__VA_ARGS__` drops the preceding comma when no extra arguments are given (GNU extension).
 * The expansion ends without a semicolon, so the caller supplies it.
 */
#define DEFINE_STUCT_TEMPLATE_FIELD(T, FIELD, C, ...) \
T<C, ##__VA_ARGS__> FIELD; \
constexpr static bool __CONST_TYPE_##FIELD = false
/*
 * Declares an array member `T FIELD[NUM]` together with the flag `__CONST_TYPE_<FIELD> = false`.
 * The expansion ends without a semicolon, so the caller supplies it.
 */
#define DEFINE_STUCT_ARRAY_FIELD(T, FIELD, NUM) \
T FIELD[NUM]; \
constexpr static bool __CONST_TYPE_##FIELD = false
}
#endif