* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file reflection_pad3d_grad_flat.h
* \brief
*/
#ifndef REFLECTION_PAD3D_GRAD_FLAT_H
#define REFLECTION_PAD3D_GRAD_FLAT_H
#include "reflection_pad3d_grad_init.h"
template <typename T>
__aicore__ inline void ReflectionPad3dGrad<T>::FlatProcess()
{
if (width < MAX_LINE + MAX_LINE) {
WidthFlatProcess();
} else {
HeightFlatProcess();
}
}
template <typename T>
__aicore__ inline void ReflectionPad3dGrad<T>::WidthFlatProcess()
{
event_t eventIDMTE3ToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
int32_t blockHeight = FloorDiv(ubFactorElement, CeilAlign(width, 16));
int32_t gmLoop_Height = CeilDiv(height - MAX_LINE - MAX_LINE, blockHeight);
for (size_t loop = 0; loop < loopNC; loop++) {
for (size_t i = 0; i < curDepth; i++) {
size_t cur_D = GetCurD(i);
bool isAtomicAdd = true;
int64_t gmXOffset = (loop * curDepth * height * width + i * height * width);
int64_t gmYOffset = (loop * curOutDepth * outHeight * outWidth + cur_D * outHeight * outWidth);
CopyInSmall(gmXOffset, MAX_LINE, width, 0);
ComputeSmall(0, 1, 0, 0, MAX_LINE, width);
CopyOutSmall(
gmYOffset, hPad1 * CeilAlign(width, 16), isAtomicAdd, MAX_LINE - hPad1, outWidth, CeilAlign(width, 16),
0);
SetFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
for (size_t j = 0; j < gmLoop_Height; j++) {
gmXOffset =
(loop * curDepth * height * width + i * height * width + MAX_LINE * width +
j * blockHeight * width);
gmYOffset =
(loop * curOutDepth * outHeight * outWidth + cur_D * outHeight * outWidth +
(MAX_LINE - hPad1) * outWidth + j * blockHeight * outWidth);
int64_t calHeight = blockHeight;
if (j == gmLoop_Height - 1) {
calHeight = height - MAX_LINE - MAX_LINE - j * blockHeight;
}
CopyInSmall(gmXOffset, calHeight, width, 0);
ComputeSmall(1, 1, 0, 0, calHeight, width);
CopyOutSmall(gmYOffset, 0, isAtomicAdd, calHeight, outWidth, CeilAlign(width, 16), 0);
SetFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
}
gmXOffset = (loop * curDepth * height * width + i * height * width + (height - MAX_LINE) * width);
gmYOffset =
(loop * curOutDepth * outHeight * outWidth + cur_D * outHeight * outWidth +
(outHeight - (MAX_LINE - hPad2)) * outWidth);
CopyInSmall(gmXOffset, MAX_LINE, width, 0);
ComputeSmall(1, 0, 0, 0, MAX_LINE, width);
CopyOutSmall(gmYOffset, 0, isAtomicAdd, MAX_LINE - hPad2, outWidth, CeilAlign(width, 16), 0);
SetFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
}
}
}
template <typename T>
__aicore__ inline void ReflectionPad3dGrad<T>::HeightFlatProcess()
{
event_t eventIDMTE3ToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
int32_t blockWidth = FloorAlign(FloorDiv(ubFactorElement, CeilAlign(height, 16)), 16);
int32_t gmLoop_Width = CeilDiv(width - MAX_LINE - MAX_LINE, blockWidth);
for (size_t loop = 0; loop < loopNC; loop++) {
for (size_t i = 0; i < curDepth; i++) {
size_t cur_D = GetCurD(i);
bool isAtomicAdd = true;
int64_t gmXOffset = (loop * curDepth * height * width + i * height * width);
int64_t gmYOffset = (loop * curOutDepth * outHeight * outWidth + cur_D * outHeight * outWidth);
CopyInSmall(gmXOffset, height, MAX_LINE, (width - MAX_LINE) * sizeof(T));
ComputeSmall(0, 0, 0, 1, height, MAX_LINE);
CopyOutSmall(
gmYOffset, hPad1 * MAX_LINE, isAtomicAdd, outHeight, MAX_LINE - wPad1, CeilAlign(MAX_LINE, 16),
(outWidth - (MAX_LINE - wPad1)) * sizeof(T));
SetFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
for (size_t j = 0; j < gmLoop_Width; j++) {
gmXOffset = (loop * curDepth * height * width + i * height * width + MAX_LINE + j * blockWidth);
gmYOffset =
(loop * curOutDepth * outHeight * outWidth + cur_D * outHeight * outWidth + (MAX_LINE - wPad1) +
j * blockWidth);
int64_t calWidth = blockWidth;
if (j == gmLoop_Width - 1) {
calWidth = width - MAX_LINE - MAX_LINE - j * blockWidth;
}
CopyInSmall(gmXOffset, height, calWidth, (width - calWidth) * sizeof(T));
ComputeSmall(0, 0, 1, 1, height, calWidth);
CopyOutSmall(
gmYOffset, hPad1 * CeilAlign(calWidth, 16), isAtomicAdd, outHeight, calWidth,
CeilAlign(calWidth, 16), (outWidth - calWidth) * sizeof(T));
SetFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
}
gmXOffset = (loop * curDepth * height * width + i * height * width + width - MAX_LINE);
gmYOffset =
(loop * curOutDepth * outHeight * outWidth + cur_D * outHeight * outWidth + outWidth -
(MAX_LINE - wPad2));
CopyInSmall(gmXOffset, height, MAX_LINE, (width - MAX_LINE) * sizeof(T));
ComputeSmall(0, 0, 1, 0, height, MAX_LINE);
CopyOutSmall(
gmYOffset, hPad1 * MAX_LINE, isAtomicAdd, outHeight, MAX_LINE - wPad2, CeilAlign(MAX_LINE, 16),
(outWidth - (MAX_LINE - wPad2)) * sizeof(T));
SetFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIDMTE3ToMTE2);
}
}
}
#endif