* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_split_tensor.h"
#include "split_v.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn_kernels/slice.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/op_dfx.h"
#include "opdev/platform.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "op_api/aclnn_check.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
// Highest tensor rank accepted for self and for each output tensor (enforced in CheckShape).
constexpr size_t MAX_DIM_LEN{8};
// Number of sections handled per SplitV launch in SplitLoopCalculation when IsRegBase() is false.
constexpr int64_t SPLIT_LOOP_SIZE{32};
// Number of sections handled per SplitV launch in SplitLoopCalculation when IsRegBase() is true;
// also the section count above which aclnnSplitTensorGetWorkspaceSize always takes the loop path.
constexpr int64_t SPLIT_LOOP_SIZE_512{512};
// Dtypes accepted for self and every output tensor when IsRegBase() is false.
static const std::initializer_list<DataType> DTYPE_SUPPORT_LIST = {
    DataType::DT_DOUBLE, DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16,
    DataType::DT_INT64,  DataType::DT_INT32, DataType::DT_INT16,   DataType::DT_INT8,
    DataType::DT_UINT8,  DataType::DT_BOOL,
    DataType::DT_COMPLEX128, DataType::DT_COMPLEX64,
};
// Dtypes accepted when IsRegBase() is true: the list above plus DT_UINT64, DT_UINT32 and DT_UINT16.
// NOTE(review): the "_950" suffix presumably names the target chip series — confirm.
static const std::initializer_list<DataType> DTYPE_SUPPORT_LIST_950 = {
    DataType::DT_DOUBLE, DataType::DT_FLOAT,  DataType::DT_FLOAT16, DataType::DT_BF16,
    DataType::DT_INT64,  DataType::DT_UINT64, DataType::DT_INT32,   DataType::DT_UINT32,
    DataType::DT_INT16,  DataType::DT_UINT16, DataType::DT_INT8,    DataType::DT_UINT8,
    DataType::DT_BOOL,
    DataType::DT_COMPLEX128, DataType::DT_COMPLEX64,
};
inline static bool CheckNotNull(const aclTensor *self, const aclTensorList *out) {
OP_CHECK_NULL(self, return false);
OP_CHECK_NULL(out, return false);
return true;
}
inline static bool CheckDtypeValid(const aclTensor *self, const aclTensorList *out) {
if (IsRegBase()) {
OP_CHECK_DTYPE_NOT_SUPPORT(self, DTYPE_SUPPORT_LIST_950, return false);
for (size_t index = 0; index < out->Size(); index++) {
OP_CHECK_DTYPE_NOT_SUPPORT((*out)[index], DTYPE_SUPPORT_LIST_950, return false);
}
return true;
}
OP_CHECK_DTYPE_NOT_SUPPORT(self, DTYPE_SUPPORT_LIST, return false);
for (size_t index = 0; index < out->Size(); index++) {
OP_CHECK_DTYPE_NOT_SUPPORT((*out)[index], DTYPE_SUPPORT_LIST, return false);
}
return true;
}
static bool CheckShape(const aclTensor *self, uint64_t splitSections, int64_t dim, const aclTensorList *out) {
OP_CHECK_MAX_DIM(self, MAX_DIM_LEN, return false);
OP_CHECK_MIN_DIM(self, 1, return false);
for (size_t index = 0; index < out->Size(); index++) {
OP_CHECK_MAX_DIM((*out)[index], MAX_DIM_LEN, return false);
}
int64_t selfDim = static_cast<int64_t>(self->GetViewShape().GetDimNum());
if ((dim >= 0 && dim >= selfDim) || (dim < 0 && dim < -selfDim)) {
OP_LOGE(ACLNN_ERR_PARAM_NULLPTR,
"Expected aclnnSplitTensor dim value [%ld] to be in range [%ld, %ld) but check failed.",
dim, -selfDim, selfDim);
return false;
}
size_t dimIndex = dim >= 0 ? static_cast<size_t>(dim) : static_cast<size_t>(dim + selfDim);
int64_t splitShape = self->GetViewShape().GetDim(dimIndex);
if (splitShape != 0 && splitSections == 0) {
OP_LOGE(ACLNN_ERR_PARAM_NULLPTR,
"Expected aclnnSplitTensor splitSections to not be zero while split dim size is not zero but got [%lu].",
splitSections);
return false;
}
if (splitShape == 0 && splitSections != 0) {
OP_LOGE(ACLNN_ERR_PARAM_NULLPTR,
"Expected aclnnSplitTensor splitSections to be zero while split dim size is zero but got [%lu].",
splitSections);
return false;
}
return true;
}
/**
 * Runs all parameter checks in order: null pointers, dtypes, then shapes.
 * @return ACLNN_ERR_PARAM_NULLPTR if a pointer check fails, ACLNN_ERR_PARAM_INVALID if a dtype or
 *         shape check fails, ACLNN_SUCCESS otherwise.
 */
inline static aclnnStatus CheckParams(const aclTensor *self, uint64_t splitSections, int64_t dim,
                                      const aclTensorList *out) {
    // Pointer validity comes first because the later checks dereference self and out.
    const bool pointersValid = CheckNotNull(self, out);
    CHECK_RET(pointersValid, ACLNN_ERR_PARAM_NULLPTR);

    const bool dtypesValid = CheckDtypeValid(self, out);
    CHECK_RET(dtypesValid, ACLNN_ERR_PARAM_INVALID);

    const bool shapesValid = CheckShape(self, splitSections, dim, out);
    CHECK_RET(shapesValid, ACLNN_ERR_PARAM_INVALID);

    return ACLNN_SUCCESS;
}
/**
 * Handles the single-piece case (the split-axis size equals splitSections).
 * self is cast to the dtype of out[0] and copied into out[0] unchanged.
 * @return ACLNN_ERR_PARAM_INVALID if out is empty or out[0] has a mismatched shape,
 *         ACLNN_ERR_INNER_NULLPTR if an l0 op fails, ACLNN_SUCCESS otherwise.
 */
inline static aclnnStatus SplitZeroCalculation(const aclTensor *self, aclTensorList *out, aclOpExecutor *executor) {
    // out[0] is indexed below, so an empty output list must be rejected instead of read out of bounds.
    if (out->Size() == 0) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID,
            "Expected aclnnSplitTensor out to contain at least one tensor, but got an empty tensor list.");
        return ACLNN_ERR_PARAM_INVALID;
    }
    // Same output-shape validation that SplitOnceCalculation applies to every piece.
    CHECK_RET(CheckShapeAndScalarSame(self, (*out)[0]), ACLNN_ERR_PARAM_INVALID);
    auto selfCast = l0op::Cast(self, (*out)[0]->GetDataType(), executor);
    CHECK_RET(selfCast != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto selfViewCopy = l0op::ViewCopy(selfCast, (*out)[0], executor);
    CHECK_RET(selfViewCopy != nullptr, ACLNN_ERR_INNER_NULLPTR);
    return ACLNN_SUCCESS;
}
/**
 * Splits self along dim with a single SplitV launch. Each resulting piece is cast to the dtype of
 * the matching out tensor and copied into it.
 * @param splitSize section lengths along dim; piece i is written to out[i].
 * @return ACLNN_ERR_INNER_NULLPTR if an l0 op fails, ACLNN_ERR_PARAM_INVALID if out has too few
 *         tensors or a piece/out shape mismatches, ACLNN_SUCCESS otherwise.
 */
static aclnnStatus SplitOnceCalculation(const aclTensor *self, const aclIntArray *splitSize, int64_t dim,
                                        aclTensorList *out, aclOpExecutor *executor) {
    auto splitRes = l0op::SplitV(self, splitSize, dim, executor);
    // Checked on its own: the size diagnostic below dereferences splitRes.
    CHECK_RET(splitRes != nullptr, ACLNN_ERR_INNER_NULLPTR);
    const size_t numSplit = splitSize->Size();
    if ((numSplit > out->Size()) || (numSplit > splitRes->Size())) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID,
            "Index value exceeds the out size range, splitSize->Size=%lu, splitRes->Size=%lu, out->Size=%lu",
            splitSize->Size(), splitRes->Size(), out->Size());
        return ACLNN_ERR_PARAM_INVALID;
    }
    for (size_t index = 0; index < numSplit; index++) {
        CHECK_RET(CheckShapeAndScalarSame((*splitRes)[index], (*out)[index]), ACLNN_ERR_PARAM_INVALID);
        auto splitCast = l0op::Cast((*splitRes)[index], (*out)[index]->GetDataType(), executor);
        CHECK_RET(splitCast != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto splitViewCopy = l0op::ViewCopy(splitCast, (*out)[index], executor);
        CHECK_RET(splitViewCopy != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    return ACLNN_SUCCESS;
}
/**
 * Splits self along dim when there are too many sections for one SplitV launch.
 * The sections are grouped into batches of splitLoopSize: SPLIT_LOOP_SIZE normally, or
 * SPLIT_LOOP_SIZE_512 when IsRegBase() is true; the last batch may be shorter.
 * For each batch, the covered range of self is sliced out, split with SplitV, and every piece
 * is cast and copied into out[batchIndex * splitLoopSize + pieceIndex].
 * @return ACLNN_ERR_INNER_NULLPTR if an allocation or l0 op fails, ACLNN_ERR_PARAM_INVALID if out
 *         has too few tensors or a piece/out shape mismatches, ACLNN_SUCCESS otherwise.
 */
static aclnnStatus SplitLoopCalculation(const aclTensor *self, const aclIntArray *splitSize, int64_t dim,
                                        aclTensorList *out, aclOpExecutor *executor) {
    const int64_t numSplit = static_cast<int64_t>(splitSize->Size());
    const int64_t splitLoopSize = IsRegBase() ? SPLIT_LOOP_SIZE_512 : SPLIT_LOOP_SIZE;
    const int64_t loopSize = (numSplit + splitLoopSize - 1) / splitLoopSize;
    const int64_t lastSize = (numSplit % splitLoopSize == 0) ? splitLoopSize : numSplit % splitLoopSize;
    op::Shape selfShape = self->GetViewShape();
    const size_t selfDim = selfShape.GetDimNum();
    const size_t splitDim = static_cast<size_t>(dim);
    const auto *splitData = splitSize->GetData();

    // Pass 1: cut the section list into batches, remembering each batch's lengths and total length.
    FVector<int64_t> newSplitSize;
    FVector<aclIntArray *> splitList;
    for (int64_t loopIndex = 0; loopIndex < loopSize; loopIndex++) {
        const int64_t batchLen = (loopIndex == loopSize - 1) ? lastSize : splitLoopSize;
        const int64_t batchStart = loopIndex * splitLoopSize;
        FVector<int64_t> chunkVector;
        int64_t newSplit = 0;
        for (int64_t inner = 0; inner < batchLen; inner++) {
            const int64_t currentSplitValue = splitData[batchStart + inner];
            chunkVector.emplace_back(currentSplitValue);
            newSplit += currentSplitValue;
        }
        aclIntArray *chunkArray = executor->AllocIntArray(chunkVector.data(), chunkVector.size());
        CHECK_RET(chunkArray != nullptr, ACLNN_ERR_INNER_NULLPTR);
        splitList.emplace_back(chunkArray);
        newSplitSize.emplace_back(newSplit);
    }

    // Pass 2: slice each batch's range out of self, split it, and copy the pieces into out.
    int64_t offsetValue = 0;
    for (size_t sliceIndex = 0; sliceIndex < newSplitSize.size(); sliceIndex++) {
        FVector<int64_t> offsetVector(selfDim, 0);
        offsetVector[splitDim] = offsetValue;
        aclIntArray *offsetArray = executor->AllocIntArray(offsetVector.data(), offsetVector.size());
        CHECK_RET(offsetArray != nullptr, ACLNN_ERR_INNER_NULLPTR);

        FVector<int64_t> sizeVector;
        for (size_t selfIndex = 0; selfIndex < selfDim; selfIndex++) {
            sizeVector.emplace_back(selfIndex == splitDim ? newSplitSize[sliceIndex] : selfShape.GetDim(selfIndex));
        }
        aclIntArray *sizeArray = executor->AllocIntArray(sizeVector.data(), sizeVector.size());
        CHECK_RET(sizeArray != nullptr, ACLNN_ERR_INNER_NULLPTR);

        auto sliceRes = l0op::Slice(self, offsetArray, sizeArray, executor);
        CHECK_RET(sliceRes != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto splitRes = l0op::SplitV(sliceRes, splitList[sliceIndex], dim, executor);
        CHECK_RET(splitRes != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // Index of the first out tensor fed by this batch; computed once, in unsigned arithmetic.
        const size_t outBase = sliceIndex * static_cast<size_t>(splitLoopSize);
        for (int64_t resIndex = 0; resIndex < static_cast<int64_t>(splitRes->Size()); resIndex++) {
            const size_t outIndex = outBase + static_cast<size_t>(resIndex);
            if (outIndex >= out->Size()) {
                OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                    "Index value exceeds the out size range, resIndex=%ld, sliceIndex=%zu, out->Size=%lu",
                    resIndex, sliceIndex, out->Size());
                return ACLNN_ERR_PARAM_INVALID;
            }
            // Same output-shape validation as SplitOnceCalculation.
            CHECK_RET(CheckShapeAndScalarSame((*splitRes)[resIndex], (*out)[outIndex]), ACLNN_ERR_PARAM_INVALID);
            auto splitCast = l0op::Cast((*splitRes)[resIndex], (*out)[outIndex]->GetDataType(), executor);
            CHECK_RET(splitCast != nullptr, ACLNN_ERR_INNER_NULLPTR);
            auto splitViewCopy = l0op::ViewCopy(splitCast, (*out)[outIndex], executor);
            CHECK_RET(splitViewCopy != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }
        // The next batch starts where this one ends along the split axis.
        offsetValue += newSplitSize[sliceIndex];
    }
    return ACLNN_SUCCESS;
}
/**
 * First phase of aclnnSplitTensor. Validates the arguments and records the work on a new executor:
 * self is cut along dim into ceil(size(dim) / splitSections) pieces, each of length splitSections
 * except possibly the last, and the pieces are written into out.
 * @param self          input tensor.
 * @param splitSections length of each piece along dim (must be 0 exactly when size(dim) is 0).
 * @param dim           split axis, in [-rank, rank).
 * @param out           output tensors, one per piece.
 * @param workspaceSize receives the executor's workspace size (0 for an empty self).
 * @param executor      receives the prepared executor.
 * @return ACLNN_SUCCESS, or the first error status from validation or graph construction.
 */
aclnnStatus aclnnSplitTensorGetWorkspaceSize(const aclTensor *self, uint64_t splitSections, int64_t dim,
                                             aclTensorList *out, uint64_t *workspaceSize, aclOpExecutor **executor) {
    L2_DFX_PHASE_1(aclnnSplitTensor, DFX_IN(self, splitSections, dim), DFX_OUT(out));
    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    auto ret = CheckParams(self, splitSections, dim, out);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    if (dim < 0) {
        dim += static_cast<int64_t>(self->GetViewShape().GetDimNum());
    }
    const int64_t dimSize = self->GetViewShape().GetDim(static_cast<size_t>(dim));
    if (self->IsEmpty()) {
        *workspaceSize = 0;
        uniqueExecutor.ReleaseTo(executor);
        return ACLNN_SUCCESS;
    }
    auto selfContiguous = l0op::Contiguous(self, uniqueExecutor.get());
    CHECK_RET(selfContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // Past the IsEmpty() early return dimSize > 0, so CheckShape guarantees splitSections != 0.
    const uint64_t dimLen = static_cast<uint64_t>(dimSize);
    if (dimLen == splitSections) {
        ret = SplitZeroCalculation(selfContiguous, out, uniqueExecutor.get());
    } else {
        // Ceil division in unsigned arithmetic: the old `dimSize + splitSections - 1` form overflowed
        // (or went negative through the int64 cast) for very large splitSections.
        const uint64_t remainder = dimLen % splitSections;
        const int64_t numSplit = static_cast<int64_t>(dimLen / splitSections + (remainder != 0 ? 1 : 0));
        // When splitSections exceeds the axis size there is exactly one piece covering the whole axis.
        const int64_t fullChunk = splitSections < dimLen ? static_cast<int64_t>(splitSections) : dimSize;
        const int64_t lastSplitSize = remainder != 0 ? static_cast<int64_t>(remainder) : fullChunk;
        FVector<int64_t> splitVector(static_cast<size_t>(numSplit), fullChunk);
        splitVector[static_cast<size_t>(numSplit - 1)] = lastSplitSize;
        aclIntArray *splitSize = uniqueExecutor.get()->AllocIntArray(splitVector.data(), splitVector.size());
        CHECK_RET(splitSize != nullptr, ACLNN_ERR_INNER_NULLPTR);
        // Batched (multi-launch) path: for more than SPLIT_LOOP_SIZE sections on non-RegBase
        // platforms with AiCore SplitV support, or for more than SPLIT_LOOP_SIZE_512 sections anywhere.
        const bool useLoop =
            (l0op::SplitVAiCoreSupport(selfContiguous) && numSplit > SPLIT_LOOP_SIZE && !IsRegBase()) ||
            numSplit > SPLIT_LOOP_SIZE_512;
        ret = useLoop ? SplitLoopCalculation(selfContiguous, splitSize, dim, out, uniqueExecutor.get())
                      : SplitOnceCalculation(selfContiguous, splitSize, dim, out, uniqueExecutor.get());
    }
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}
/**
 * Second phase of aclnnSplitTensor. Runs the given executor (presumably the one produced by
 * aclnnSplitTensorGetWorkspaceSize) on stream with the caller-provided workspace.
 * @return the status reported by CommonOpExecutorRun.
 */
aclnnStatus aclnnSplitTensor(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
{
    L2_DFX_PHASE_2(aclnnSplitTensor);
    const aclnnStatus runStatus = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
    return runStatus;
}
#ifdef __cplusplus
}
#endif