* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_stack.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "aclnn_kernels/contiguous.h"
#include "op_api/aclnn_check.h"
#include "../../concat_d/op_api/concat_d.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "pack.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
static const std::initializer_list<op::DataType> ASCEND910_DTYPE_DTYPE_SUPPORT_LIST = {
op::DataType::DT_INT8, op::DataType::DT_INT16, op::DataType::DT_INT32, op::DataType::DT_INT64,
op::DataType::DT_UINT8, op::DataType::DT_UINT16, op::DataType::DT_UINT32, op::DataType::DT_UINT64,
op::DataType::DT_FLOAT16, op::DataType::DT_FLOAT, op::DataType::DT_BOOL, op::DataType::DT_DOUBLE,
op::DataType::DT_COMPLEX64, op::DataType::DT_COMPLEX128};
static const std::initializer_list<op::DataType> ASCEND910B_DTYPE_DTYPE_SUPPORT_LIST = {
op::DataType::DT_INT8, op::DataType::DT_INT16, op::DataType::DT_INT32, op::DataType::DT_INT64,
op::DataType::DT_UINT8, op::DataType::DT_UINT16, op::DataType::DT_UINT32, op::DataType::DT_UINT64,
op::DataType::DT_FLOAT16, op::DataType::DT_FLOAT, op::DataType::DT_BOOL, op::DataType::DT_DOUBLE,
op::DataType::DT_COMPLEX64, op::DataType::DT_COMPLEX128, op::DataType::DT_BF16};
static bool CheckNotNull(const aclTensorList* tensors, const int64_t* realDim, const aclTensor* out)
{
if (tensors == nullptr || realDim == nullptr) {
OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "Input of aclnnStack should not be null.");
return false;
}
OP_CHECK_NULL(out, return false);
return true;
}
static const std::initializer_list<DataType>& GetDtypeSupportList()
{
if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910B ||
GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910_93 ||
IsRegBase()) {
return ASCEND910B_DTYPE_DTYPE_SUPPORT_LIST;
} else {
return ASCEND910_DTYPE_DTYPE_SUPPORT_LIST;
}
}
static bool CheckDtypeValid(const aclTensorList* tensors, const aclTensor* out)
{
auto supportList = GetDtypeSupportList();
for (uint64_t i = 0; i < tensors->Size(); i++) {
if (!CheckType((*tensors)[i]->GetDataType(), supportList)) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "tensor %lu not implemented for %s, should be in dtype support list [%s].", i,
op::ToString((*tensors)[i]->GetDataType()).GetString(), op::ToString(supportList).GetString());
return false;
}
}
OP_CHECK_DTYPE_NOT_SUPPORT(out, supportList, return false);
return true;
}
static bool CheckPromoteType(const aclTensorList* tensors, const aclTensor* out)
{
op::DataType promoteType = (*tensors)[0]->GetDataType();
for (uint64_t i = 1; i < tensors->Size(); i++) {
promoteType = op::PromoteType((*tensors)[i]->GetDataType(), promoteType);
if (promoteType == DataType::DT_UNDEFINED) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "tensor %lu dtype %s and dtype %s can not promote dtype.", i,
op::ToString((*tensors)[i]->GetDataType()).GetString(), op::ToString(promoteType).GetString());
return false;
}
}
OP_CHECK_RESULT_DTYPE_CAST_FAILED(promoteType, out->GetDataType(), return false);
return true;
}
static bool CheckShape(const aclTensorList* tensors, int64_t* realDim)
{
op::Shape shapeFirst = (*tensors)[0]->GetViewShape();
auto dimNum = (int64_t)shapeFirst.GetDimNum();
static const int64_t MAX_DIM_LEN = 8;
if (dimNum > MAX_DIM_LEN) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Dim of self is %ld, can't be greater than %ld.", dimNum, MAX_DIM_LEN);
return false;
}
auto originDim = *realDim;
if (*realDim < 0) {
(*realDim) += dimNum + 1;
}
if ((*realDim) < 0 || (*realDim) > dimNum) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "dimnum %ld exceed the dim range of the [%ld, %ld].", originDim, -(dimNum + 1),
dimNum);
return false;
}
for (uint64_t i = 1; i < tensors->Size(); i++) {
op::Shape shapeCurrent = (*tensors)[i]->GetViewShape();
if (dimNum != (int64_t)shapeCurrent.GetDimNum()) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "dimnum of tensor %lu is [%zu], should be equal to tensor 0 [%ld].", i,
shapeCurrent.GetDimNum(), dimNum);
return false;
}
for (int64_t j = 0; j < dimNum; j++) {
if (shapeFirst.GetDim(j) != shapeCurrent.GetDim(j)) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "dim %ld of tensor %lu is [%ld], should be equal to tensor 0 [%ld].", j, i,
shapeCurrent.GetDim(j), shapeFirst.GetDim(j));
return false;
}
}
}
return true;
}
static aclnnStatus CheckParams(const aclTensorList* tensors, int64_t* realDim, const aclTensor* out)
{
CHECK_RET(CheckNotNull(tensors, realDim, out), ACLNN_ERR_INNER_NULLPTR);
CHECK_RET(CheckDtypeValid(tensors, out), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckPromoteType(tensors, out), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckShape(tensors, realDim), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
static bool EmplaceTensorList(
const aclTensorList* tensors, op::FVector<const aclTensor*>& tensorListA, DataType& promoteType)
{
tensorListA.emplace_back((*tensors)[0]);
for (uint64_t i = 1; i < tensors->Size(); i++) {
promoteType = op::PromoteType((*tensors)[i]->GetDataType(), promoteType);
tensorListA.emplace_back((*tensors)[i]);
}
return true;
}
static aclnnStatus SplitToStack(
const aclTensorList* tensors, int64_t dim, aclOpExecutor* executor, const aclTensor** out)
{
size_t maxInputs = 32;
op::FVector<const aclTensor*> tensorListA;
auto promoteType = (*tensors)[0]->GetDataType();
bool firstLoop = EmplaceTensorList(tensors, tensorListA, promoteType);
while (tensorListA.size() > 1) {
op::FVector<const aclTensor*> tensorListOnce;
op::FVector<const aclTensor*> tensorLlistB;
for (auto tensor : tensorListA) {
if (tensor->IsEmpty()) {
continue;
}
if (firstLoop) {
auto contiguous = l0op::Contiguous(tensor, executor);
CHECK_RET(contiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto castOut = l0op::Cast(contiguous, promoteType, executor);
CHECK_RET(castOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
tensorListOnce.emplace_back(castOut);
} else {
tensorListOnce.emplace_back(tensor);
}
if (tensorListOnce.size() == maxInputs) {
auto tensorList = executor->AllocTensorList(tensorListOnce.data(), tensorListOnce.size());
const aclTensor* stackTensor = nullptr;
if (firstLoop) {
stackTensor = l0op::Pack(tensorList, dim, promoteType, executor);
} else {
stackTensor = l0op::ConcatD(tensorList, dim, promoteType, executor);
}
CHECK_RET(stackTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
tensorLlistB.emplace_back(stackTensor);
tensorListOnce.clear();
}
}
if (!tensorListOnce.empty()) {
auto aclTensorListTail = executor->AllocTensorList(tensorListOnce.data(), tensorListOnce.size());
const aclTensor* stackTensorTail = nullptr;
if (firstLoop) {
stackTensorTail = l0op::Pack(aclTensorListTail, dim, promoteType, executor);
} else {
stackTensorTail = l0op::ConcatD(aclTensorListTail, dim, promoteType, executor);
}
CHECK_RET(stackTensorTail != nullptr, ACLNN_ERR_INNER_NULLPTR);
tensorLlistB.emplace_back(stackTensorTail);
tensorListOnce.clear();
}
tensorListA = tensorLlistB;
firstLoop = false;
}
*out = tensorListA.empty() ? nullptr : tensorListA.front();
return ACLNN_SUCCESS;
}
aclnnStatus aclnnStackGetWorkspaceSize(
const aclTensorList* tensors, int64_t dim, aclTensor* out, uint64_t* workspaceSize, aclOpExecutor** executor)
{
L2_DFX_PHASE_1(aclnnStack, DFX_IN(tensors, dim), DFX_OUT(out));
auto uniqueExecutor = CREATE_EXECUTOR();
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
int64_t realDim = dim;
auto ret = CheckParams(tensors, &realDim, out);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
const aclTensor* castIn = nullptr;
if (tensors->Size() == 1) {
if ((*tensors)[0]->IsEmpty()) {
*workspaceSize = 0;
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
auto contiguous = l0op::Contiguous((*tensors)[0], uniqueExecutor.get());
CHECK_RET(contiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
op::FVector<const aclTensor*> tensorList{contiguous};
auto oneTensor = uniqueExecutor.get()->AllocTensorList(tensorList.data(), tensorList.size());
castIn = l0op::Pack(oneTensor, realDim, (*tensors)[0]->GetDataType(), uniqueExecutor.get());
} else {
aclnnStatus retSplit = SplitToStack(tensors, realDim, uniqueExecutor.get(), &castIn);
CHECK_RET(retSplit == ACLNN_SUCCESS, retSplit);
}
if (castIn == nullptr) {
*workspaceSize = 0;
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
CHECK_RET(CheckShapeAndScalarSame(castIn, out), ACLNN_ERR_PARAM_INVALID);
auto castOut = l0op::Cast(castIn, out->GetDataType(), uniqueExecutor.get());
CHECK_RET(castOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto viewCopyResult = l0op::ViewCopy(castOut, out, uniqueExecutor.get());
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnStack(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, const aclrtStream stream)
{
L2_DFX_PHASE_2(aclnnStack);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif