* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_strided_slice.h"
#include "strided_slice.h"
#include "aclnn_kernels/contiguous.h"
#include "opdev/make_op_executor.h"
#include "opdev/op_dfx.h"
#include "opdev/op_log.h"
#include "opdev/platform.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "op_api/aclnn_check.h"
#include "op_api/op_api_def.h"
#include "aclnn_kernels/transdata.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
// Element count of an empty index array. aclnnStridedSliceGetWorkspaceSize compares begin->Size()
// against this value to select the plain-copy path.
static constexpr size_t ARRAY_SIZE = 0;

// Data types that CheckDtypeValid accepts for both the input and the output tensor.
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
    op::DataType::DT_FLOAT16,
    op::DataType::DT_FLOAT,
    op::DataType::DT_INT32,
    op::DataType::DT_UINT8,
    op::DataType::DT_BOOL,
    op::DataType::DT_INT8,
    op::DataType::DT_INT16,
    op::DataType::DT_INT64,
    op::DataType::DT_UINT16,
    op::DataType::DT_UINT32,
    op::DataType::DT_UINT64,
    op::DataType::DT_BF16,
    op::DataType::DT_COMPLEX32,
    op::DataType::DT_COMPLEX64,
    op::DataType::DT_HIFLOAT8,
    op::DataType::DT_FLOAT8_E5M2,
    op::DataType::DT_FLOAT8_E4M3FN};
/**
 * @brief Verifies that every mandatory pointer argument of aclnnStridedSlice is non-null.
 *
 * Each argument is passed to OP_CHECK_NULL with `return false` as the failure action, so the
 * function returns at the first null pointer it meets, in the order self, begin, end, strides, out.
 *
 * @param self    Input tensor.
 * @param begin   Begin index array.
 * @param end     End index array.
 * @param strides Stride array.
 * @param out     Output tensor.
 * @return true when all five pointers are non-null, false otherwise.
 */
static inline bool CheckNotNull(
const aclTensor* self, const aclIntArray* begin, const aclIntArray* end, const aclIntArray* strides, aclTensor* out)
{
// NOTE(review): OP_CHECK_NULL presumably also logs the offending argument; confirm in its definition.
OP_CHECK_NULL(self, return false);
OP_CHECK_NULL(begin, return false);
OP_CHECK_NULL(end, return false);
OP_CHECK_NULL(strides, return false);
OP_CHECK_NULL(out, return false);
return true;
}
/**
 * @brief Validates the data types of the input and output tensors.
 *
 * The operator is rejected outright when IsRegBase() is false. Otherwise both tensors must use a
 * type from DTYPE_SUPPORT_LIST and must share the same type.
 *
 * @param self Input tensor (must be non-null).
 * @param out  Output tensor (must be non-null).
 * @return true when the platform check and all dtype checks pass, false otherwise.
 */
static bool CheckDtypeValid(const aclTensor* self, aclTensor* out)
{
    // Guard clause: nothing else is checked on a platform IsRegBase() does not accept.
    if (!IsRegBase()) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnStridedSlice only support ASCEND950.");
        return false;
    }

    // Each macro runs `return false` when its condition is violated.
    OP_CHECK_DTYPE_NOT_SUPPORT(self, DTYPE_SUPPORT_LIST, return false);
    OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
    OP_CHECK_DTYPE_NOT_SAME(self, out, return false);
    return true;
}
/**
 * @brief Checks the dimension count of the input tensor against MAX_SUPPORT_DIMS_NUMS.
 *
 * MAX_SUPPORT_DIMS_NUMS is defined outside this file. OP_CHECK_MAX_DIM is given `return false`
 * as its failure action.
 *
 * @param self Input tensor (must be non-null).
 * @return true when the check passes, false otherwise.
 */
static bool CheckInputDims(const aclTensor* self)
{
OP_CHECK_MAX_DIM(self, MAX_SUPPORT_DIMS_NUMS, return false);
return true;
}
/**
 * @brief Validates the begin/end/strides index arrays.
 *
 * The three arrays must contain the same number of entries, and no stride may be zero.
 *
 * @param begin   Begin index array (must be non-null).
 * @param end     End index array (must be non-null).
 * @param strides Stride array (must be non-null).
 * @return true when the arrays are consistent, false (after logging the reason) otherwise.
 */
static bool CheckArray(const aclIntArray* begin, const aclIntArray* end, const aclIntArray* strides)
{
    const auto beginLen = begin->Size();
    const auto endLen = end->Size();
    if (beginLen != endLen) {
        OP_LOGE(
            ACLNN_ERR_PARAM_INVALID, "Expected aclnnStridedSlice begin.size() %lu to be equal to end.size() %lu.",
            beginLen, endLen);
        return false;
    }

    const auto stridesLen = strides->Size();
    if (endLen != stridesLen) {
        OP_LOGE(
            ACLNN_ERR_PARAM_INVALID, "Expected aclnnStridedSlice end.size() %lu to be equal to strides.size() %lu.",
            endLen, stridesLen);
        return false;
    }

    // Advance to the first zero stride, if there is one.
    size_t zeroPos = 0;
    while (zeroPos < stridesLen && (*strides)[zeroPos] != 0) {
        ++zeroPos;
    }
    if (zeroPos == stridesLen) {
        return true;
    }

    OP_LOGE(
        ACLNN_ERR_PARAM_INVALID, "Expected strides value must not be zero, but strides No.[%lu] value is zero.",
        zeroPos);
    return false;
}
/**
 * @brief Validates the ellipsis and shrink-axis masks against the stride array.
 *
 * Rules enforced:
 *  - ellipsisMask may have at most one bit set.
 *  - for every index i whose bit is set in shrinkAxisMask, strides[i] must be strictly positive.
 *
 * The masks are handled as unsigned 64-bit values so that neither `mask - 1` on INT64_MIN nor a
 * shift by 64 or more bits (possible when strides holds more entries than the mask has bits) can
 * trigger undefined behaviour.
 *
 * @param strides        Stride array (must be non-null).
 * @param ellipsisMask   Bit mask marking the ellipsis position.
 * @param shrinkAxisMask Bit mask marking the dimensions to shrink.
 * @return true when both masks are acceptable, false (after logging the reason) otherwise.
 */
static bool CheckInputMask(const aclIntArray* strides, int64_t ellipsisMask, int64_t shrinkAxisMask)
{
    // x & (x - 1) clears the lowest set bit; a non-zero remainder means more than one bit was set.
    const uint64_t ellipsisBits = static_cast<uint64_t>(ellipsisMask);
    if ((ellipsisBits != 0) && ((ellipsisBits & (ellipsisBits - 1)) != 0)) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Multiple ellipses in slice spec not allowed.");
        return false;
    }

    // Indices at or beyond the mask width cannot be selected by the mask, so they are not inspected.
    constexpr size_t MASK_BIT_WIDTH = 64;
    const uint64_t shrinkBits = static_cast<uint64_t>(shrinkAxisMask);
    const size_t strideCount = strides->Size();
    const size_t checkedCount = (strideCount < MASK_BIT_WIDTH) ? strideCount : MASK_BIT_WIDTH;
    for (size_t i = 0; i < checkedCount; i++) {
        if (((shrinkBits >> i) & 1U) == 0) {
            continue;
        }
        if ((*strides)[i] <= 0) {
            OP_LOGE(
                ACLNN_ERR_PARAM_INVALID,
                "Strides must be positive when shrinkAxisMask has bit set at dimension [%lu].", i);
            return false;
        }
    }
    return true;
}
/**
 * @brief Runs all parameter checks for aclnnStridedSlice in a fixed order.
 *
 * The null check runs first because the later helpers dereference their arguments.
 *
 * @param self           Input tensor.
 * @param begin          Begin index array.
 * @param end            End index array.
 * @param strides        Stride array.
 * @param ellipsisMask   Ellipsis bit mask.
 * @param shrinkAxisMask Shrink-axis bit mask.
 * @param out            Output tensor.
 * @return ACLNN_ERR_PARAM_NULLPTR when a pointer is null, ACLNN_ERR_PARAM_INVALID when any other
 *         check fails, ACLNN_SUCCESS otherwise.
 */
static aclnnStatus CheckParams(
const aclTensor* self, const aclIntArray* begin, const aclIntArray* end, const aclIntArray* strides,
int64_t ellipsisMask, int64_t shrinkAxisMask, aclTensor* out)
{
// 1. Mandatory pointers must be non-null.
CHECK_RET(CheckNotNull(self, begin, end, strides, out), ACLNN_ERR_PARAM_NULLPTR);
// 2. Platform and data-type support.
CHECK_RET(CheckDtypeValid(self, out), ACLNN_ERR_PARAM_INVALID);
// 3. Dimension-count limit of the input.
CHECK_RET(CheckInputDims(self), ACLNN_ERR_PARAM_INVALID);
// 4. begin/end/strides sizes agree and no stride is zero.
CHECK_RET(CheckArray(begin, end, strides), ACLNN_ERR_PARAM_INVALID);
// 5. Mask constraints.
CHECK_RET(CheckInputMask(strides, ellipsisMask, shrinkAxisMask), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
/**
 * @brief First phase of aclnnStridedSlice: validates the arguments, builds the operator pipeline
 *        on a new executor and reports the workspace size it needs.
 *
 * Pipeline:
 *  - empty input tensor: nothing is queued, workspace size is 0;
 *  - empty begin array: the contiguous input is copied to out unchanged;
 *  - otherwise: Contiguous -> ReFormat(ND) -> StridedSlice -> ViewCopy into out.
 *
 * @param self           Input tensor.
 * @param begin          Begin index array.
 * @param end            End index array.
 * @param strides        Stride array; entries must be non-zero.
 * @param beginMask      Forwarded to l0op::StridedSlice.
 * @param endMask        Forwarded to l0op::StridedSlice.
 * @param ellipsisMask   At most one bit may be set; forwarded to l0op::StridedSlice.
 * @param newAxisMask    Forwarded to l0op::StridedSlice.
 * @param shrinkAxisMask Forwarded to l0op::StridedSlice; selected dimensions need a positive stride.
 * @param out            Output tensor; its shape must match the computed result.
 * @param workspaceSize  [out] Receives the required workspace size in bytes.
 * @param executor       [out] Receives ownership of the prepared executor.
 * @return ACLNN_SUCCESS on success, otherwise the ACLNN_ERR_* code of the first failing step.
 */
aclnnStatus aclnnStridedSliceGetWorkspaceSize(
    const aclTensor* self, const aclIntArray* begin, const aclIntArray* end, const aclIntArray* strides,
    int64_t beginMask, int64_t endMask, int64_t ellipsisMask, int64_t newAxisMask, int64_t shrinkAxisMask,
    aclTensor* out, uint64_t* workspaceSize, aclOpExecutor** executor)
{
    L2_DFX_PHASE_1(
        aclnnStridedSlice,
        DFX_IN(self, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask), DFX_OUT(out));

    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);

    auto ret = CheckParams(self, begin, end, strides, ellipsisMask, shrinkAxisMask, out);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);

    // An empty input needs no kernels: hand back an executor with nothing queued.
    if (self->IsEmpty()) {
        *workspaceSize = 0;
        uniqueExecutor.ReleaseTo(executor);
        return ACLNN_SUCCESS;
    }

    auto selfContiguous = l0op::Contiguous(self, uniqueExecutor.get());
    CHECK_RET(selfContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);

    if (begin->Size() == ARRAY_SIZE) {
        // No slice spec at all: the result is the input itself, so out must have the input's shape.
        // This mirrors the shape validation performed on the slicing path below.
        CHECK_RET(CheckShapeAndScalarSame(selfContiguous, out), ACLNN_ERR_PARAM_INVALID);
        auto viewCopyResult = l0op::ViewCopy(selfContiguous, out, uniqueExecutor.get());
        CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
    } else {
        auto selfReformat = l0op::ReFormat(selfContiguous, Format::FORMAT_ND, uniqueExecutor.get());
        CHECK_RET(selfReformat != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // The host-side index arrays are converted to INT64 tensors for the kernel; a failed
        // conversion must not be forwarded as a null tensor.
        auto beginTensor = uniqueExecutor.get()->ConvertToTensor(begin, op::ToOpDataType(ACL_INT64));
        CHECK_RET(beginTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto endTensor = uniqueExecutor.get()->ConvertToTensor(end, op::ToOpDataType(ACL_INT64));
        CHECK_RET(endTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto stridesTensor = uniqueExecutor.get()->ConvertToTensor(strides, op::ToOpDataType(ACL_INT64));
        CHECK_RET(stridesTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);

        auto stridedsliceOut = l0op::StridedSlice(
            selfReformat, beginTensor, endTensor, stridesTensor, beginMask, endMask, ellipsisMask, newAxisMask,
            shrinkAxisMask, uniqueExecutor.get());
        CHECK_RET(stridedsliceOut != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // The caller-provided output must match the shape produced by the slice.
        CHECK_RET(CheckShapeAndScalarSame(stridedsliceOut, out), ACLNN_ERR_PARAM_INVALID);
        auto viewCopyResult = l0op::ViewCopy(stridedsliceOut, out, uniqueExecutor.get());
        CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }

    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}
/**
 * @brief Second phase of aclnnStridedSlice: runs the executor prepared by
 *        aclnnStridedSliceGetWorkspaceSize.
 *
 * @param workspace     Workspace buffer supplied by the caller.
 * @param workspaceSize Size of the workspace buffer.
 * @param executor      Executor produced by the first phase.
 * @param stream        Stream passed through to CommonOpExecutorRun.
 * @return The status returned by CommonOpExecutorRun.
 */
aclnnStatus aclnnStridedSlice(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
// NOTE(review): L2_DFX_PHASE_2 presumably records DFX/trace information for this call; confirm.
L2_DFX_PHASE_2(aclnnStridedSlice);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif