* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_scale.h"
#include "scale.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/contiguous.h"
#include "op_api/aclnn_check.h"
#include "opdev/op_log.h"
#include "opdev/op_dfx.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/make_op_executor.h"
#include "aclnn_kernels/common/op_error_check.h"
using namespace op;
static constexpr size_t MAX_DIM_LEN = 8;
static const std::initializer_list<DataType> X_DTYPE_SUPPORT_LIST_ASCEND910B = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
static const std::initializer_list<DataType> X_DTYPE_SUPPORT_LIST_ASCEND310P = {op::DataType::DT_FLOAT,
op::DataType::DT_FLOAT16};
static const std::initializer_list<DataType> EMPTY_LIST = {};
static const std::initializer_list<DataType>& GetDtypeSupportList() {
SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
if (socVersion == SocVersion::ASCEND910_93 || socVersion == SocVersion::ASCEND910B || IsRegBase()) {
return X_DTYPE_SUPPORT_LIST_ASCEND910B;
} else if (socVersion == SocVersion::ASCEND310P) {
return X_DTYPE_SUPPORT_LIST_ASCEND310P;
} else {
OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "support for %s is not implemented", op::ToString(socVersion).GetString());
return EMPTY_LIST;
}
}
static inline bool CheckNotNull(const aclTensor* x, const aclTensor* scale, const aclTensor* out) {
OP_CHECK_NULL(x, return false);
OP_CHECK_NULL(scale, return false);
OP_CHECK_NULL(out, return false);
return true;
}
static bool CheckDtypeValid(const aclTensor* x, const aclTensor* scale, const aclTensor* bias, const aclTensor* out) {
OP_CHECK_DTYPE_NOT_SUPPORT(x, GetDtypeSupportList(), return false);
if (bias != nullptr) {
OP_CHECK_DTYPE_NOT_SAME(scale, bias, return false);
}
OP_CHECK_DTYPE_NOT_SAME(x, scale, return false);
OP_CHECK_DTYPE_NOT_SAME(x, out, return false);
return true;
}
static bool CheckShapeFromBlob(const aclTensor* x, const aclTensor* scale, int64_t axis, int64_t numAxes) {
if (numAxes == 0 && scale->GetViewShape().GetShapeSize() != 1) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "when num_axes is 0, scale must have one num, but shape is %s.",
op::ToString(scale->GetViewShape()).GetString());
return false;
}
if (numAxes == 0) {
return true;
}
op::Shape xShape = x->GetViewShape();
int64_t xRank = static_cast<int64_t>(xShape.GetDimNum());
int64_t newAxis = axis >= 0 ? axis : xRank + axis;
int64_t scaleLen = numAxes == -1L ? xRank - newAxis : numAxes;
if ((scaleLen + newAxis) > xRank || scaleLen <= 0) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "the sum of scale axis(%ld) and numAxes(%ld) must be <= xRank(%ld).", newAxis,
numAxes, xRank);
return false;
}
op::Shape expectScaleShape;
expectScaleShape.SetDimNum(scaleLen);
for (int64_t i = 0; i < scaleLen; i++) {
expectScaleShape.SetDim(i, xShape.GetDim(newAxis + i));
}
if (expectScaleShape != scale->GetViewShape()) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"the scale must be %s, when x is %s and attr[axis(%ld), numAxes(%ld)], but get %s.",
op::ToString(expectScaleShape).GetString(), op::ToString(xShape).GetString(), axis, numAxes,
op::ToString(scale->GetViewShape()).GetString());
return false;
}
return true;
}
static bool CheckShapeNoFromBlob(const aclTensor* x, const aclTensor* scale, int64_t axis) {
op::Shape xShape = x->GetViewShape();
op::Shape scaleShape = scale->GetViewShape();
int64_t xRank = static_cast<int64_t>(xShape.GetDimNum());
int64_t newAxis = axis >= 0 ? axis : xRank + axis;
int64_t scaleLen = static_cast<int64_t>(scaleShape.GetDimNum());
if ((scaleLen + newAxis) > xRank || scaleLen < 0) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "the sum of scale axis(%ld) and scale_rank(%ld) must be <= xRank(%ld).", newAxis,
scaleLen, xRank);
return false;
}
if (scaleShape.GetShapeSize() == 1 && scaleLen <= 1) {
return true;
}
for (int64_t i = 0; i < scaleLen; i++) {
if (xShape.GetDim(newAxis + i) != scaleShape.GetDim(i)) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "the scale shape(%s) must be the same with x shape(%s) from axis dim(%ld).",
op::ToString(scale->GetViewShape()).GetString(), op::ToString(xShape).GetString(), axis);
return false;
}
}
return true;
}
static bool CheckDim(const aclTensor* x, int64_t axis, int64_t numAxes) {
op::Shape xShape = x->GetViewShape();
int64_t xRank = static_cast<int64_t>(xShape.GetDimNum());
if (axis < -xRank || axis >= xRank) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "axis must be within the range of [%ld, %ld), but it is %ld.", -xRank, xRank,
axis);
return false;
}
if (numAxes < -1L) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "numAxes must be >= -1, but it is %ld.", numAxes);
return false;
}
return true;
}
static bool CheckShape(const aclTensor* x, const aclTensor* scale, const aclTensor* bias, int64_t axis,
int64_t numAxes, bool scaleFromBlob, const aclTensor* out) {
OP_CHECK_SHAPE_NOT_EQUAL(out, x, return false);
if (bias != nullptr) {
OP_CHECK_SHAPE_NOT_EQUAL(bias, scale, return false);
}
OP_CHECK_MAX_DIM(x, MAX_DIM_LEN, return false);
OP_CHECK_MAX_DIM(scale, MAX_DIM_LEN, return false);
CHECK_RET(CheckDim(x, axis, numAxes), false);
if (scaleFromBlob) {
CHECK_RET(CheckShapeFromBlob(x, scale, axis, numAxes), false);
} else {
CHECK_RET(CheckShapeNoFromBlob(x, scale, axis), false);
}
return true;
}
static aclnnStatus CheckParams(const aclTensor* x, const aclTensor* scale, const aclTensor* bias, int64_t axis,
int64_t numAxes, bool scaleFromBlob, const aclTensor* out) {
CHECK_RET(CheckNotNull(x, scale, out), ACLNN_ERR_PARAM_NULLPTR);
CHECK_RET(CheckDtypeValid(x, scale, bias, out), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckShape(x, scale, bias, axis, numAxes, scaleFromBlob, out), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnScaleGetWorkspaceSize(const aclTensor* x, const aclTensor* scale, const aclTensor* bias,
int64_t axis, int64_t numAxes, bool scaleFromBlob,
aclTensor* y, uint64_t* workspaceSize, aclOpExecutor** executor) {
L2_DFX_PHASE_1(aclnnScale, DFX_IN(x, scale, bias, axis, numAxes, scaleFromBlob), DFX_OUT(y));
auto uniqueExecutor = CREATE_EXECUTOR();
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
auto ret = CheckParams(x, scale, bias, axis, numAxes, scaleFromBlob, y);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
if (x->IsEmpty()) {
*workspaceSize = 0UL;
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
auto xContiguous = l0op::Contiguous(x, uniqueExecutor.get());
CHECK_RET(xContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto scaleContiguous = l0op::Contiguous(scale, uniqueExecutor.get());
CHECK_RET(scaleContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
const aclTensor* biasContiguous = nullptr;
if (bias != nullptr) {
biasContiguous = l0op::Contiguous(bias, uniqueExecutor.get());
CHECK_RET(biasContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
auto scaleResult = l0op::Scale(xContiguous, scaleContiguous, biasContiguous, axis, numAxes,
scaleFromBlob, uniqueExecutor.get());
CHECK_RET(scaleResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto viewCopyResult = l0op::ViewCopy(scaleResult, y, uniqueExecutor.get());
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnScale(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream) {
L2_DFX_PHASE_2(aclnnScale);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}