* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_im2col.h"
#include "im2col.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/contiguous.h"
#include "conversion/squeeze/op_host/op_api/squeeze.h"
#include "conversion/unsqueeze/op_host/op_api/unsqueeze.h"
#include "aclnn_kernels/transdata.h"
#include "aclnn_kernels/reshape.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/shape_utils.h"
#include "op_api/aclnn_check.h"
using namespace op;
static constexpr size_t NEED_SQUEEZE = 3;
static constexpr size_t NO_NEED_SQUEEZE = 4;
static constexpr size_t ARRAY_SIZE = 2;
static constexpr size_t PADDING_SIZE = 4;
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST_REGBASE = {
op::DataType::DT_INT8, op::DataType::DT_UINT8, op::DataType::DT_INT16, op::DataType::DT_UINT16,
op::DataType::DT_INT32, op::DataType::DT_UINT32, op::DataType::DT_INT64, op::DataType::DT_UINT64,
op::DataType::DT_BF16, op::DataType::DT_FLOAT16, op::DataType::DT_FLOAT, op::DataType::DT_DOUBLE,
op::DataType::DT_BOOL, op::DataType::DT_COMPLEX32, op::DataType::DT_COMPLEX64};
template <typename T>
static inline auto div_rtn(T x, T y) -> T
{
if (y == 0) {
OP_LOGE(ACL_ERROR_INVALID_PARAM, "Division by zero!");
return -1;
}
int q = x / y;
int r = x % y;
if ((r != 0) && ((r < 0) != (y < 0))) {
--q;
};
return q;
}
#ifdef __cplusplus
extern "C" {
#endif
static inline bool CheckNotNull(
const aclTensor* self, const aclIntArray* kernelSize, const aclIntArray* dilation, const aclIntArray* padding,
const aclIntArray* stride, const aclTensor* out)
{
OP_CHECK_NULL(self, return false);
OP_CHECK_NULL(kernelSize, return false);
OP_CHECK_NULL(dilation, return false);
OP_CHECK_NULL(padding, return false);
OP_CHECK_NULL(stride, return false);
OP_CHECK_NULL(out, return false);
return true;
}
static bool CheckInputDims(const aclTensor* self)
{
auto selfDimNum = self->GetViewShape().GetDimNum();
if (selfDimNum != NEED_SQUEEZE && selfDimNum != NO_NEED_SQUEEZE) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "Expected self dim [%zu] to be 3 or 4 but check failed.",
self->GetViewShape().GetDimNum());
return false;
}
const op::Shape selfShape = self->GetViewShape();
size_t index = selfDimNum == NO_NEED_SQUEEZE ? 1 : 0;
for (size_t i = index; i < selfDimNum; i++) {
if (selfShape.GetDim(i) <= 0) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "self'dims is invalid, self No.[%lu] dim is [%ld].", i + 1,
selfShape.GetDim(i));
return false;
}
}
return true;
}
static bool CheckOutputDims(
const aclTensor* self, const aclIntArray* kernelSize, const aclIntArray* dilation, const aclIntArray* padding,
const aclIntArray* stride, const aclTensor* out)
{
bool isNeedSqueeze = (self->GetViewShape().GetDimNum() == NEED_SQUEEZE);
int64_t inputHeight = isNeedSqueeze ? self->GetViewShape().GetDim(1) : self->GetViewShape().GetDim(2);
int64_t inputWidth = isNeedSqueeze ? self->GetViewShape().GetDim(2) : self->GetViewShape().GetDim(3);
int64_t outputHeight =
div_rtn<int64_t>(
(inputHeight + 2 * (*padding)[0] - ((*dilation)[0] * ((*kernelSize)[0] - 1) + 1)), (*stride)[0]) +
1;
int64_t outputWidth =
div_rtn<int64_t>(
(inputWidth + 2 * (*padding)[1] - ((*dilation)[1] * ((*kernelSize)[1] - 1) + 1)), (*stride)[1]) +
1;
if (outputHeight < 1 || outputWidth < 1) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID,
"The shape (%ld, %ld) of the array calculated by other parameters "
"must be at least one.",
outputHeight, outputWidth);
return false;
}
const op::Shape outShape =
isNeedSqueeze ?
op::Shape(
{self->GetViewShape().GetDim(0) * (*kernelSize)[0] * (*kernelSize)[1], outputHeight * outputWidth}) :
op::Shape(
{self->GetViewShape().GetDim(0), self->GetViewShape().GetDim(1) * (*kernelSize)[0] * (*kernelSize)[1],
outputHeight * outputWidth});
if (outShape != out->GetViewShape()) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "Expect out shape [%s], but got: [%s].", op::ToString(outShape).GetString(),
op::ToString(out->GetViewShape()).GetString());
return false;
}
return true;
}
static bool CheckArray(
const aclIntArray* kernelSize, const aclIntArray* dilation, const aclIntArray* padding, const aclIntArray* stride)
{
if (kernelSize->Size() != ARRAY_SIZE) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "It is expected kernelSize equals to 2, but got size %lu.", kernelSize->Size());
return false;
}
if (dilation->Size() != ARRAY_SIZE) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "It is expected dilation equals to 2, but got size %lu.", dilation->Size());
return false;
}
if (padding->Size() != ARRAY_SIZE) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "It is expected padding equals to 2, but got size %lu.", padding->Size());
return false;
}
if (stride->Size() != ARRAY_SIZE) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "It is expected stride equals to 2, but got size %lu.", stride->Size());
return false;
}
for (size_t i = 0; i < kernelSize->Size(); i++) {
if ((*kernelSize)[i] <= 0) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID,
"It is expected kernelSize be greater than zero, "
"but kernelSize No.[%lu] dim is [%ld].",
i + 1, (*kernelSize)[i]);
return false;
}
}
for (size_t i = 0; i < stride->Size(); i++) {
if ((*stride)[i] <= 0) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID,
"It is expected stride be greater than zero, "
"but stride No.[%lu] dim is [%ld].",
i + 1, (*stride)[i]);
return false;
}
}
for (size_t i = 0; i < dilation->Size(); i++) {
if ((*dilation)[i] <= 0) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID,
"It is expected dilation be greater than zero, "
"but dilation No.[%lu] dim is [%ld].",
i + 1, (*dilation)[i]);
return false;
}
}
for (size_t i = 0; i < padding->Size(); i++) {
if ((*padding)[i] < 0) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID,
"It is expected padding be greater than or equal to zero, "
"but padding No.[%lu] dim is [%ld].",
i + 1, (*padding)[i]);
return false;
}
}
return true;
}
static void CheckFormat(const aclTensor* self)
{
if (self->GetStorageFormat() == Format::FORMAT_FRACTAL_NZ) {
OP_LOGW("Format of self gets [%s], this format may lead to precision failure.",
op::ToString(self->GetStorageFormat()).GetString());
}
}
static aclnnStatus CheckParams(
const aclTensor* self, const aclIntArray* kernelSize, const aclIntArray* dilation, const aclIntArray* padding,
const aclIntArray* stride, const aclTensor* out)
{
CHECK_RET(CheckNotNull(self, kernelSize, dilation, padding, stride, out), ACLNN_ERR_PARAM_NULLPTR);
auto curArch = GetCurrentPlatformInfo().GetCurNpuArch();
if (IsRegBase(curArch)) {
OP_CHECK_DTYPE_NOT_SUPPORT(self, DTYPE_SUPPORT_LIST_REGBASE, return ACLNN_ERR_PARAM_INVALID);
} else {
OP_CHECK_DTYPE_NOT_SUPPORT(self, DTYPE_SUPPORT_LIST, return ACLNN_ERR_PARAM_INVALID);
}
CHECK_RET(CheckInputDims(self), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckArray(kernelSize, dilation, padding, stride), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckOutputDims(self, kernelSize, dilation, padding, stride, out), ACLNN_ERR_PARAM_INVALID);
CheckFormat(self);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnIm2colGetWorkspaceSize(
const aclTensor* self, const aclIntArray* kernelSize, const aclIntArray* dilation, const aclIntArray* padding,
const aclIntArray* stride, const aclTensor* out, uint64_t* workspaceSize, aclOpExecutor** executor)
{
L2_DFX_PHASE_1(aclnnIm2col, DFX_IN(self, kernelSize, dilation, padding, stride), DFX_OUT(out));
auto uniqueExecutor = CREATE_EXECUTOR();
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
auto ret = CheckParams(self, kernelSize, dilation, padding, stride, out);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
if (self->IsEmpty()) {
*workspaceSize = 0;
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
bool isNeedSqueeze = (self->GetViewShape().GetDimNum() == NEED_SQUEEZE);
auto selfContiguous = l0op::Contiguous(self, uniqueExecutor.get());
CHECK_RET(selfContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto selfUnsqueeze = isNeedSqueeze ?
l0op::UnsqueezeNd(selfContiguous, static_cast<int64_t>(0), uniqueExecutor.get()) :
selfContiguous;
CHECK_RET(selfUnsqueeze != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto selfReFormat = l0op::ReFormat(selfUnsqueeze, op::Format::FORMAT_NCHW);
CHECK_RET(selfReFormat != nullptr, ACLNN_ERR_INNER_NULLPTR);
FVector<int64_t> padding4d = {(*padding)[0], (*padding)[0], (*padding)[1], (*padding)[1]};
const aclIntArray* newPadding = uniqueExecutor.get()->AllocIntArray(padding4d.data(), PADDING_SIZE);
auto im2colOut = l0op::Im2col(selfReFormat, kernelSize, dilation, newPadding, stride, uniqueExecutor.get());
CHECK_RET(im2colOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto outSqueeze =
isNeedSqueeze ? l0op::SqueezeNd(im2colOut, static_cast<int64_t>(0), uniqueExecutor.get()) : im2colOut;
CHECK_RET(outSqueeze != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto outView =
uniqueExecutor.get()->CreateView(outSqueeze, outSqueeze->GetViewShape(), outSqueeze->GetViewOffset());
CHECK_RET(outView != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto outReFormat = l0op::ReFormat(outView, out->GetViewFormat());
CHECK_RET(outReFormat != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto outCast = l0op::Cast(outReFormat, out->GetDataType(), uniqueExecutor.get());
CHECK_RET(outCast != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto viewCopyResult = l0op::ViewCopy(outCast, out, uniqueExecutor.get());
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
L2_DFX_PHASE_2(aclnnIm2col);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif