/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_roll.h"
#include "roll.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn_kernels/transpose.h"
#include "conversion/unsqueeze/op_host/op_api/unsqueeze.h"
#include "conversion/squeeze/op_host/op_api/squeeze.h"
#include "opdev/platform.h"
#include "op_api/aclnn_check.h"
#include "op_api/op_api_def.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
// Fallback dtype list: CheckDtypeValid passes it as the second ("other chips") argument of
// GetDtypeSupportListV1, so it applies when the SoC is neither ASCEND910B nor ASCEND910_93
// and IsRegBase() is false.
static const std::initializer_list<op::DataType> ASCEND910_DTYPE_DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_INT8, op::DataType::DT_UINT8,
op::DataType::DT_INT32, op::DataType::DT_UINT32, op::DataType::DT_BOOL, op::DataType::DT_INT64};
// Dtype list used on ASCEND910B / ASCEND910_93 / IsRegBase() platforms: the entries above
// plus DT_BF16 and DT_INT16.
static const std::initializer_list<op::DataType> ASCEND910B_DTYPE_DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_INT8,
op::DataType::DT_UINT8, op::DataType::DT_INT32, op::DataType::DT_UINT32,
op::DataType::DT_BOOL, op::DataType::DT_INT64, op::DataType::DT_BF16, op::DataType::DT_INT16};
* l1: ASCEND910B、ASCEND910_93 或者 ASCEND950 芯片,该算子支持的数据类型列表
* l2: 其他芯片,该算子支持的数据类型列表
*/
static const std::initializer_list<DataType>& GetDtypeSupportListV1(
const std::initializer_list<op::DataType>& l1, const std::initializer_list<op::DataType>& l2)
{
if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910B ||
GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910_93 ||
IsRegBase()) {
return l1;
} else {
return l2;
}
}
// Returns false as soon as one of x, shifts, dims or out is a null pointer, true otherwise.
// Each OP_CHECK_NULL executes its `return false` action for a null argument (the macro is
// defined elsewhere; presumably it also logs the offending parameter -- confirm).
static inline bool CheckNotNull(
const aclTensor* x, const aclIntArray* shifts, const aclIntArray* dims, const aclTensor* out)
{
OP_CHECK_NULL(x, return false);
OP_CHECK_NULL(shifts, return false);
OP_CHECK_NULL(dims, return false);
OP_CHECK_NULL(out, return false);
return true;
}
// Validates the data types of x and out; returns false when any check below fails.
// Callers must have verified that x and out are non-null (out is dereferenced here).
static inline bool CheckDtypeValid(const aclTensor* x, const aclTensor* out)
{
// Platform-dependent list: the ASCEND910B list is passed as l1, the ASCEND910 list as l2.
const auto& supportList =
GetDtypeSupportListV1(ASCEND910B_DTYPE_DTYPE_SUPPORT_LIST, ASCEND910_DTYPE_DTYPE_SUPPORT_LIST);
// x and out are each checked against the selected list.
OP_CHECK_DTYPE_NOT_SUPPORT(x, supportList, return false);
OP_CHECK_DTYPE_NOT_SUPPORT(out, supportList, return false);
// x is additionally checked against out's dtype (no implicit conversion between them).
OP_CHECK_DTYPE_NOT_MATCH(x, out->GetDataType(), return false);
return true;
}
// Returns false when OP_CHECK_SHAPE_NOT_EQUAL rejects the pair (x, out), true otherwise.
// The macro is defined elsewhere; by its name it compares the two tensors' shapes -- confirm.
static inline bool CheckShape(const aclTensor* x, const aclTensor* out)
{
OP_CHECK_SHAPE_NOT_EQUAL(x, out, return false);
return true;
}
/**
 * Validates the lengths of shifts and dims against each other.
 *
 * - dims empty:     shifts must contain exactly one element.
 * - dims non-empty: shifts must have the same number of elements as dims.
 *
 * Logs an error and returns false on violation; returns true otherwise.
 */
static inline bool CheckArraySize(const aclIntArray* shifts, const aclIntArray* dims)
{
    const auto shiftCount = shifts->Size();
    const auto dimCount = dims->Size();

    if (dimCount == 0U) {
        if (shiftCount == 1U) {
            return true;
        }
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The size of shifts must be 1 if the size of dims is 0.");
        return false;
    }

    if (shiftCount == dimCount) {
        return true;
    }
    OP_LOGE(
        ACLNN_ERR_PARAM_INVALID, "The size of shifts and dims should be the same when the size of dims is not 0.");
    return false;
}
/**
 * Checks that every entry of dims lies in [dimMin, dimMax], where the two bounds are the
 * smaller and larger of -rank and rank - 1 and rank is the number of view dimensions of x.
 *
 * Logs an error and returns false on the first out-of-range entry; returns true otherwise
 * (including when dims is empty).
 */
static bool CheckDimsRange(const aclTensor* x, const aclIntArray* dims)
{
    const auto tensorDimSize = static_cast<int64_t>(x->GetViewShape().GetDimNum());
    // The bounds depend only on the tensor rank, so compute them once instead of per element.
    const int64_t dimMax = std::max(-1 * tensorDimSize, tensorDimSize - 1);
    const int64_t dimMin = std::min(-1 * tensorDimSize, tensorDimSize - 1);

    const int64_t dimSize = static_cast<int64_t>(dims->Size());
    for (int64_t i = 0; i < dimSize; i++) {
        const int64_t curDim = (*dims)[i];
        if ((curDim > dimMax) || (curDim < dimMin)) {
            OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The values of dims should be in range [%ld, %ld].", dimMin, dimMax);
            return false;
        }
    }
    return true;
}
// Returns false when OP_CHECK_MAX_DIM rejects x against MAX_SUPPORT_DIMS_NUMS, true otherwise.
// Both the macro and the constant are defined elsewhere; by their names this caps the rank of x.
static inline bool CheckTensorDimSize(const aclTensor* x)
{
OP_CHECK_MAX_DIM(x, MAX_SUPPORT_DIMS_NUMS, return false);
return true;
}
// Runs the argument checks in a fixed order and returns the status tied to the first one that
// fails, or ACLNN_SUCCESS when all pass. The range of the dims values is NOT checked here;
// aclnnRollGetWorkspaceSize calls CheckDimsRange separately.
static aclnnStatus CheckParams(
const aclTensor* x, const aclIntArray* shifts, const aclIntArray* dims, const aclTensor* out)
{
// 1. Null pointers -- must come first, the later checks dereference the arguments.
CHECK_RET(CheckNotNull(x, shifts, dims, out), ACLNN_ERR_PARAM_NULLPTR);
// 2. Data types of x and out.
CHECK_RET(CheckDtypeValid(x, out), ACLNN_ERR_PARAM_INVALID);
// 3. x against out shape check.
CHECK_RET(CheckShape(x, out), ACLNN_ERR_PARAM_INVALID);
// 4. Lengths of shifts and dims.
CHECK_RET(CheckArraySize(shifts, dims), ACLNN_ERR_PARAM_INVALID);
// 5. Dimension-count limit on x.
CHECK_RET(CheckTensorDimSize(x), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
/**
 * Converts a possibly negative dimension index into a non-negative one by adding the rank.
 *
 * dim:         dimension index; non-negative values are returned unchanged.
 * dimPostExpr: rank used for wrapping; a value of 0 is treated as 1.
 *
 * No range validation is done: a dim below -rank stays negative after the single addition.
 */
static inline int64_t WrapDim(int64_t dim, uint32_t dimPostExpr)
{
    const int64_t rank = (dimPostExpr == 0U) ? 1 : static_cast<int64_t>(dimPostExpr);
    return (dim < 0) ? (dim + rank) : dim;
}
/**
 * Builds a one-element aclIntArray holding x via executor->AllocIntArray.
 *
 * Returns whatever AllocIntArray returns; the result is not checked here.
 */
static inline aclIntArray* GetIntArray(int64_t x, aclOpExecutor* executor)
{
    int64_t singleValue[1] = {x};
    return executor->AllocIntArray(singleValue, 1);
}
/**
 * Rolls `self` by `shift` along `axis` using an l0op::Roll call that always acts on dim 0.
 *
 * For axis == 0 the roll is applied directly. Otherwise `axis` is swapped with dim 0 via
 * l0op::Transpose, the roll is applied on dim 0, and the same transpose is applied again
 * (a two-element swap permutation is its own inverse) to restore the original layout.
 *
 * self:     input tensor; made contiguous first.
 * axis:     non-negative dimension index, expected to be in [0, rank).
 * shift:    shift amount forwarded to l0op::Roll.
 * executor: executor used for allocations and op launches.
 *
 * Returns the rolled tensor, or nullptr if any step fails or axis is out of range.
 */
static const aclTensor* roll_transpose(const aclTensor* self, int64_t axis, int64_t shift, aclOpExecutor* executor)
{
    auto selfContiguous = l0op::Contiguous(self, executor);
    // The result is dereferenced below, so bail out early if Contiguous failed.
    CHECK_RET(selfContiguous != nullptr, nullptr);

    auto dimNow = GetIntArray(0, executor);
    CHECK_RET(dimNow != nullptr, nullptr);
    auto shiftNow = GetIntArray(shift, executor);
    CHECK_RET(shiftNow != nullptr, nullptr);

    if (axis == 0) {
        return l0op::Roll(selfContiguous, shiftNow, dimNow, executor);
    }

    auto dimSize = static_cast<int64_t>(selfContiguous->GetViewShape().GetDimNum());
    // Guard the perm[axis] access below against an out-of-range index.
    CHECK_RET(axis > 0 && axis < dimSize, nullptr);

    // Identity permutation with `axis` and 0 exchanged.
    std::vector<int64_t> perm(dimSize);
    for (int64_t i = 0; i < dimSize; i++) {
        perm[i] = i;
    }
    std::swap(perm[axis], perm[0]);
    auto valuePerm = executor->AllocIntArray(perm.data(), dimSize);
    CHECK_RET(valuePerm != nullptr, nullptr);

    selfContiguous = l0op::Transpose(selfContiguous, valuePerm, executor);
    CHECK_RET(selfContiguous != nullptr, nullptr);
    selfContiguous = l0op::Roll(selfContiguous, shiftNow, dimNow, executor);
    CHECK_RET(selfContiguous != nullptr, nullptr);
    // Applying the same swap permutation again undoes the first transpose.
    selfContiguous = l0op::Transpose(selfContiguous, valuePerm, executor);
    return selfContiguous;
}
/**
 * Handles a tensor with zero dimensions: it is copied to `out` with l0op::ViewCopy and no
 * roll op is issued.
 *
 * Requires dims to be empty and shifts to hold exactly one element; otherwise logs an error
 * and returns ACLNN_ERR_PARAM_INVALID. Returns ACLNN_ERR_INNER_NULLPTR if ViewCopy yields
 * nullptr, ACLNN_SUCCESS otherwise.
 */
static aclnnStatus HandleDimZeroTensor(
    const aclTensor* self, const aclIntArray* shifts, const aclIntArray* dims, const aclTensor* out,
    aclOpExecutor* executor)
{
    const bool sizesAccepted = (dims->Size() == 0) && (shifts->Size() == 1);
    if (!sizesAccepted) {
        OP_LOGE(
            ACLNN_ERR_PARAM_INVALID,
            "When tensor x has no dimensions, shifts should be size 1, dims should be size 0.");
        return ACLNN_ERR_PARAM_INVALID;
    }

    auto copied = l0op::ViewCopy(self, out, executor);
    CHECK_RET(copied != nullptr, ACLNN_ERR_INNER_NULLPTR);
    return ACLNN_SUCCESS;
}
/**
 * First-phase entry of aclnnRoll: validates the arguments, builds the op pipeline on a new
 * executor and reports the workspace size that pipeline needs.
 *
 * x:             input tensor.
 * shifts:        shift amounts; same length as dims, or length 1 when dims is empty.
 * dims:          dimensions to roll along; may be empty.
 * out:           output tensor; checked against x for dtype and shape by CheckParams.
 * workspaceSize: [out] required workspace size.
 * executor:      [out] receives ownership of the built executor on success.
 *
 * Returns ACLNN_SUCCESS, or the status of the first failing check / pipeline step.
 *
 * Pipeline: Contiguous -> (BOOL: Cast to INT8) -> (optional SqueezeNd of dim 0) -> Roll
 *           -> (optional UnsqueezeNd) -> (BOOL: Cast back) -> ViewCopy into out.
 */
aclnnStatus aclnnRollGetWorkspaceSize(
    const aclTensor* x, const aclIntArray* shifts, const aclIntArray* dims, aclTensor* out, uint64_t* workspaceSize,
    aclOpExecutor** executor)
{
    OP_CHECK_COMM_INPUT(workspaceSize, executor);

    L2_DFX_PHASE_1(aclnnRoll, DFX_IN(x, shifts, dims), DFX_OUT(out));

    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);

    auto ret = CheckParams(x, shifts, dims, out);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);

    // Empty tensor: nothing to compute, no workspace needed.
    if (x->IsEmpty()) {
        *workspaceSize = 0;
        uniqueExecutor.ReleaseTo(executor);
        return ACLNN_SUCCESS;
    }

    // Zero-dimensional tensor: plain copy of x into out.
    if (x->GetViewShape().GetDimNum() == 0) {
        auto res = HandleDimZeroTensor(x, shifts, dims, out, uniqueExecutor.get());
        CHECK_RET(res == ACLNN_SUCCESS, res);
        *workspaceSize = uniqueExecutor->GetWorkspaceSize();
        uniqueExecutor.ReleaseTo(executor);
        return ACLNN_SUCCESS;
    }

    CHECK_RET(CheckDimsRange(x, dims), ACLNN_ERR_PARAM_INVALID);

    auto xContiguous = l0op::Contiguous(x, uniqueExecutor.get());
    CHECK_RET(xContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // BOOL input is processed as INT8 and cast back to BOOL before the final copy.
    const bool needCastBool = (x->GetDataType() == DataType::DT_BOOL);
    if (needCastBool) {
        xContiguous = l0op::Cast(xContiguous, DataType::DT_INT8, uniqueExecutor.get());
        CHECK_RET(xContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }

    int64_t dimSize = static_cast<int64_t>(dims->Size());
    const auto tensorDim = static_cast<int64_t>(x->GetViewShape().GetDimNum());
    auto newDims = dims;
    auto newshifts = shifts;

    // Accumulate the total shift requested along dimension 1 (after wrapping negative dims).
    bool hasDim1 = false;
    int64_t shiftNum = 0;
    for (int64_t i = 0; i < dimSize; i++) {
        if (WrapDim((*dims)[i], tensorDim) == 1) {
            hasDim1 = true;
            shiftNum += (*shifts)[i];
        }
    }

    // Special case: a 2-D tensor of shape [1, N] rolled along dim 1 is handled as a 1-D tensor
    // rolled along dim 0 by the accumulated dim-1 shift (any roll along the size-1 dim 0 is a
    // no-op), then unsqueezed back afterwards.
    const bool needSqueeze = (tensorDim == 2 && x->GetViewShape().GetDim(0) == 1 && hasDim1);
    aclIntArray* dimArray = nullptr;
    if (needSqueeze) {
        // The squeeze axis array is only needed on this path; allocate and verify it here.
        const int64_t appendDim[] = {0};
        dimArray = (uniqueExecutor.get())->AllocIntArray(appendDim, 1);
        CHECK_RET(dimArray != nullptr, ACLNN_ERR_INNER_NULLPTR);

        xContiguous = l0op::SqueezeNd(xContiguous, dimArray, uniqueExecutor.get());
        CHECK_RET(xContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // Both arrays are dereferenced in the per-axis loop below, so verify the allocations.
        newDims = GetIntArray(0, uniqueExecutor.get());
        CHECK_RET(newDims != nullptr, ACLNN_ERR_INNER_NULLPTR);
        newshifts = GetIntArray(shiftNum, uniqueExecutor.get());
        CHECK_RET(newshifts != nullptr, ACLNN_ERR_INNER_NULLPTR);
        dimSize = 1;
    }

    const aclTensor* outBase = nullptr;
    if (IsRegBase() || dimSize == 0) {
        // A single Roll call takes the full shifts/dims arrays (also used when dims is empty).
        outBase = l0op::Roll(xContiguous, newshifts, newDims, uniqueExecutor.get());
        CHECK_RET(outBase != nullptr, ACLNN_ERR_INNER_NULLPTR);
    } else {
        // Otherwise roll one axis at a time, each via the transpose-to-dim-0 helper.
        outBase = xContiguous;
        for (int64_t i = 0; i < dimSize; i++) {
            const int64_t axis = WrapDim((*newDims)[i], tensorDim);
            const int64_t shift = (*newshifts)[i];
            outBase = roll_transpose(outBase, axis, shift, uniqueExecutor.get());
            CHECK_RET(outBase != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }
    }

    if (needSqueeze) {
        outBase = l0op::UnsqueezeNd(outBase, dimArray, uniqueExecutor.get());
        CHECK_RET(outBase != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }

    if (needCastBool) {
        outBase = l0op::Cast(outBase, DataType::DT_BOOL, uniqueExecutor.get());
        CHECK_RET(outBase != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }

    auto viewcopyResult = l0op::ViewCopy(outBase, out, uniqueExecutor.get());
    CHECK_RET(viewcopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);

    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}
// Second-phase entry of aclnnRoll: after the L2_DFX_PHASE_2 hook, forwards workspace,
// workspaceSize, executor and stream unchanged to CommonOpExecutorRun and returns its status.
// The executor is expected to be the one produced by aclnnRollGetWorkspaceSize.
aclnnStatus aclnnRoll(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
L2_DFX_PHASE_2(aclnnRoll);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif