/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file expand.cpp
 * \brief l0op::Expand: output-shape inference and AiCore/AiCpu launch.
 */
#include "expand.h"
#include "opdev/aicpu/aicpu_task.h"
#include "opdev/make_op_executor.h"
#include "opdev/op_dfx.h"
#include "opdev/tensor_view_utils.h"
using namespace op;
namespace l0op {
// Register the Expand op type through the framework's OP_TYPE_REGISTER macro.
OP_TYPE_REGISTER(Expand);
// Input data types routed to the AiCore path on every NPU arch other than
// DAV_2201 / DAV_3510 (see IsAiCoreSupport). Any other dtype goes to AiCpu.
static const std::initializer_list<DataType> AICORE_DTYPE_SUPPORT_LIST = {
DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_INT32, DataType::DT_UINT8,
DataType::DT_INT8, DataType::DT_BOOL, DataType::DT_INT64};
// Input data types routed to the AiCore path on NPU arch DAV_2201 / DAV_3510
// (see IsAiCoreSupport). Same as the list above plus DT_BF16.
static const std::initializer_list<DataType> ASCEND910B_AICORE_DTYPE_SUPPORT_LIST = {
DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_INT32, DataType::DT_UINT8,
DataType::DT_INT8, DataType::DT_BOOL, DataType::DT_BF16, DataType::DT_INT64};
/**
 * Tells whether the data type of `self` is in the AiCore support list of the current NPU arch.
 *
 * DAV_2201 and DAV_3510 are checked against ASCEND910B_AICORE_DTYPE_SUPPORT_LIST;
 * every other arch is checked against AICORE_DTYPE_SUPPORT_LIST.
 *
 * @param self input tensor; dereferenced without a null check
 * @return the result of CheckType for the selected list
 */
static bool IsAiCoreSupport(const aclTensor* self)
{
    const auto npuArch = GetCurrentPlatformInfo().GetCurNpuArch();
    const bool usesExtendedList = (npuArch == NpuArch::DAV_2201) || (npuArch == NpuArch::DAV_3510);
    // Pick the list first so the dtype lookup is written only once.
    const auto& supportedDtypes =
        usesExtendedList ? ASCEND910B_AICORE_DTYPE_SUPPORT_LIST : AICORE_DTYPE_SUPPORT_LIST;
    return CheckType(self->GetDataType(), supportedDtypes);
}
/**
 * Adds the Expand kernel to the AiCore launcher list.
 *
 * @param self      input tensor
 * @param shape     tensor holding the requested target shape
 * @param expandOut pre-allocated output tensor
 * @param executor  not referenced by name in this body
 *                  (NOTE(review): presumably picked up by the launcher macro - confirm)
 * @return `expandOut` on success, nullptr (after logging) when the launcher macro does not
 *         report ACL_SUCCESS
 */
static const aclTensor* ExpandAiCore(
    const aclTensor* self, const aclTensor* shape, const aclTensor* expandOut, aclOpExecutor* executor)
{
    L0_DFX(ExpandAiCore, self, shape, expandOut);
    const auto launchStatus = ADD_TO_LAUNCHER_LIST_AICORE(Expand, OP_INPUT(self, shape), OP_OUTPUT(expandOut));
    if (launchStatus != ACL_SUCCESS) {
        OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "ExpandAiCore ADD_TO_LAUNCHER_LIST_AICORE failed.");
        return nullptr;
    }
    return expandOut;
}
/**
 * Adds the Expand kernel to the AiCpu launcher list.
 *
 * @param self      input tensor
 * @param shape     tensor holding the requested target shape
 * @param expandOut pre-allocated output tensor
 * @param executor  not referenced by name in this body
 *                  (NOTE(review): presumably picked up by the launcher macro - confirm)
 * @return `expandOut` on success, nullptr (after logging) when the launcher macro does not
 *         report ACL_SUCCESS
 */
static const aclTensor* ExpandAiCpu(
    const aclTensor* self, const aclTensor* shape, const aclTensor* expandOut, aclOpExecutor* executor)
{
    // Record the output tensor as well, matching the DFX call in ExpandAiCore.
    L0_DFX(ExpandAiCpu, self, shape, expandOut);
    // NOTE(review): `space` is not referenced by name below; presumably ADD_TO_LAUNCHER_LIST_AICPU
    // expects a variable with exactly this name in scope - do not rename without checking the macro.
    static internal::AicpuTaskSpace space("Expand");
    auto ret = ADD_TO_LAUNCHER_LIST_AICPU(Expand, OP_ATTR_NAMES(), OP_INPUT(self, shape), OP_OUTPUT(expandOut));
    OP_CHECK(
        ret == ACL_SUCCESS, OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "ExpandAiCpu ADD_TO_LAUNCHER_LIST_AICPU failed."),
        return nullptr);
    return expandOut;
}
/**
 * Computes the output shape produced by expanding `x` to the requested `shape`.
 *
 * `shape` is right-aligned against the view shape of `x`. For every dimension of `x`:
 *   - a requested size of -1 keeps the size `x` already has in that dimension;
 *   - otherwise the size of `x` must equal the requested size or be 1.
 * Leading dimensions of `shape` that have no counterpart in `x` are copied unchanged.
 *
 * @param x           input tensor; dereferenced without a null check
 * @param shape       requested target shape
 * @param expandShape [out] inferred output shape; its contents are not meaningful when
 *                    false is returned
 * @return true on success; false (after logging) when `shape` has fewer dimensions than `x`
 *         or a requested size is incompatible with the corresponding size of `x`
 */
bool ExpandInfershape(const aclTensor* x, const op::Shape& shape, op::Shape& expandShape)
{
    // Nothing to infer when view, storage and original shape all already equal the target.
    if (x->GetViewShape() == shape && x->GetStorageShape() == shape && x->GetOriginalShape() == shape) {
        expandShape = shape;
        return true;
    }
    const auto expandDimNum = static_cast<int64_t>(shape.GetDimNum());
    const auto& xShape = x->GetViewShape();
    const auto xDimNum = static_cast<int64_t>(xShape.GetDimNum());
    expandShape.SetDimNum(expandDimNum);
    // Number of leading target dimensions that do not exist in x.
    const int64_t lenDiff = expandDimNum - xDimNum;
    if (lenDiff < 0) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "length of shape should not be less than input_x's dimension.");
        return false;
    }
    for (int64_t i = 0; i < xDimNum; i++) {
        const auto targetDim = shape[lenDiff + i];
        if (targetDim == -1) {
            // -1 keeps the input size. Skip the rest of the iteration: the compatibility check
            // would otherwise reject the dimension (or the assignment below would store -1).
            expandShape[lenDiff + i] = xShape[i];
            continue;
        }
        if ((xShape[i] != targetDim) && (xShape[i] != 1)) {
            OP_LOGE(ACL_ERROR_INVALID_PARAM, "shape cannot be expanded to an incompatible value.");
            return false;
        }
        expandShape[lenDiff + i] = targetDim;
    }
    for (int64_t i = 0; i < lenDiff; i++) {
        expandShape[i] = shape[i];
    }
    return true;
}
/**
 * Level-0 Expand: infers the output shape, allocates the output tensor and adds the kernel
 * to the AiCore or AiCpu launcher list depending on IsAiCoreSupport(self).
 *
 * @param self     input tensor; must be contiguous, otherwise nullptr is returned after logging
 * @param shape    requested target shape
 * @param executor executor used to allocate the output tensor and to convert `shape` to a tensor
 * @return the output tensor, or nullptr when the input is not contiguous, shape inference
 *         fails, an allocation/conversion returns nullptr, or the launch helper fails
 */
const aclTensor* Expand(const aclTensor* self, const aclIntArray* shape, aclOpExecutor* executor)
{
    L0_DFX(Expand, self, shape);
    if (!op::IsContiguous(self)) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "input tensor should be contiguous.");
        return nullptr;
    }

    // Copy the requested sizes into an op::Shape so they can be fed to shape inference.
    op::Shape requestedShape;
    const auto requestedDimNum = static_cast<int64_t>(shape->Size());
    for (int64_t dim = 0; dim < requestedDimNum; ++dim) {
        requestedShape.AppendDim((*shape)[dim]);
    }

    op::Shape inferredShape;
    const bool inferOk = ExpandInfershape(self, requestedShape, inferredShape);
    if (!inferOk) {
        OP_LOGE(ACL_ERROR_FAILURE, "InferShape failed for expand operation.");
        return nullptr;
    }

    auto expandOut = executor->AllocTensor(inferredShape, self->GetDataType());
    CHECK_RET(expandOut != nullptr, nullptr);
    auto shapeTensor = executor->ConvertToTensor(shape, DataType::DT_INT64);
    CHECK_RET(shapeTensor != nullptr, nullptr);

    // Data types outside the AiCore support list go to AiCpu.
    if (!IsAiCoreSupport(self)) {
        return ExpandAiCpu(self, shapeTensor, expandOut, executor);
    }
    return ExpandAiCore(self, shapeTensor, expandOut, executor);
}
}