* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <bitset>
#include "reduce_logsumexp.h"
#include "math/add/op_api/add.h"
#include "math/sub/op_api/sub.h"
#include "conversion/squeeze/op_host/op_api/squeeze.h"
#include "math/reduce_max/op_api/reduce_max.h"
#include "conversion/masked_fill/op_api/masked_fill.h"
#include "math/abs/op_api/abs.h"
#include "math/equal/op_api/equal.h"
#include "aclnn_kernels/cast.h"
#include "conversion/fill/op_api/fill.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/shape_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/op_errno.h"
#include "op_api/aclnn_check.h"
#include "aclnn_logsumexp.h"
#include "op_api/level2_base_caculation.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
// NOTE(review): MAX_MASK_LEN is not referenced anywhere in the visible code; the shape-inference helper
// spells its bitset width as MAX_MASK_LEN64 instead. Presumably this is the intended width of that
// per-dimension mask - confirm which constant should be used.
constexpr size_t MAX_MASK_LEN = 64;
// Maximum rank accepted for both `self` and `out`; enforced through OP_CHECK_MAX_DIM in CheckShape.
constexpr size_t MAX_DIM_LEN = 8;
// Null-pointer guard for the three required arguments of the operator.
// Each OP_CHECK_NULL is handed `return false` as its failure action, so the function
// returns true only when `self`, `dim` and `out` all pass the check. The arguments are
// checked in the order self -> dim -> out.
static bool CheckNotNull(const aclTensor* self, const aclIntArray* dim, aclTensor* out) {
OP_CHECK_NULL(self, return false);
OP_CHECK_NULL(dim, return false);
OP_CHECK_NULL(out, return false);
return true;
}
// Input dtypes accepted for `self` on platforms that take the default branch of CheckDtypeValid
// (i.e. not ASCEND910B, not ASCEND910_93 and not IsRegBase()).
static const std::initializer_list<op::DataType> INPUT_DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_INT32, op::DataType::DT_INT64, op::DataType::DT_FLOAT16,
op::DataType::DT_INT16, op::DataType::DT_INT8, op::DataType::DT_UINT8, op::DataType::DT_BOOL};
// Input dtypes accepted for `self` on ASCEND910B / ASCEND910_93 / IsRegBase() platforms:
// the default list plus DT_BF16.
static const std::initializer_list<op::DataType> INPUT_DTYPE_SUPPORT_LIST_910B = {
op::DataType::DT_FLOAT, op::DataType::DT_INT32, op::DataType::DT_INT64, op::DataType::DT_FLOAT16,
op::DataType::DT_INT16, op::DataType::DT_INT8, op::DataType::DT_UINT8, op::DataType::DT_BOOL,
op::DataType::DT_BF16};
// Output dtypes accepted for `out` on the default platforms (floating point only).
static const std::initializer_list<op::DataType> OUTPUT_DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16};
// Output dtypes accepted for `out` on ASCEND910B / ASCEND910_93 / IsRegBase() platforms:
// the default list plus DT_BF16.
static const std::initializer_list<op::DataType> OUTPUT_DTYPE_SUPPORT_LIST_910B = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
/**
 * Checks the dtypes of `self` and `out` against the support lists that match the current platform.
 *
 * ASCEND910B, ASCEND910_93 and platforms for which IsRegBase() holds use the *_910B lists
 * (which additionally contain DT_BF16); every other platform uses the base lists.
 *
 * @param self input tensor whose dtype is validated against the input list
 * @param out  output tensor whose dtype is validated against the output list
 * @return true when both dtypes are accepted; OP_CHECK_DTYPE_NOT_SUPPORT is given `return false`
 *         as its failure action otherwise
 */
static bool CheckDtypeValid(const aclTensor* self, aclTensor* out) {
    const auto socVersion = op::GetCurrentPlatformInfo().GetSocVersion();
    // IsRegBase() is only consulted when neither SoC version matches, as in a plain `||` chain.
    const bool useExtendedLists = socVersion == op::SocVersion::ASCEND910B ||
                                  socVersion == op::SocVersion::ASCEND910_93 ||
                                  IsRegBase();

    if (!useExtendedLists) {
        OP_CHECK_DTYPE_NOT_SUPPORT(self, INPUT_DTYPE_SUPPORT_LIST, return false);
        OP_CHECK_DTYPE_NOT_SUPPORT(out, OUTPUT_DTYPE_SUPPORT_LIST, return false);
        return true;
    }

    OP_CHECK_DTYPE_NOT_SUPPORT(self, INPUT_DTYPE_SUPPORT_LIST_910B, return false);
    OP_CHECK_DTYPE_NOT_SUPPORT(out, OUTPUT_DTYPE_SUPPORT_LIST_910B, return false);
    return true;
}
// Checks the dtype pair (self dtype -> out dtype) with OP_CHECK_RESULT_DTYPE_CAST_FAILED,
// which is given `return false` as its failure action; returns true when the macro does not fire.
// NOTE(review): presumably the macro rejects pairs where the first dtype cannot be cast to the
// second - confirm against the macro definition.
static bool CheckPromoteType(const aclTensor* self, aclTensor* out) {
OP_CHECK_RESULT_DTYPE_CAST_FAILED(self->GetDataType(), out->GetDataType(), return false);
return true;
}
static bool CheckDimValid(const aclTensor* self, const aclIntArray* dim) {
auto selfViewShape = self->GetViewShape();
auto selfDimNum = static_cast<int64_t>(selfViewShape.GetDimNum());
if (selfDimNum <= 0) {
selfDimNum = 1;
}
for (size_t i = 0; i < dim->Size(); i++) {
if (dim->operator[](i) >= selfDimNum || dim->operator[](i) < (-selfDimNum)) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "provided dim %ld not in the range of input tensor size %ld.",
dim->operator[](i), selfDimNum);
return false;
}
}
uint64_t dimMask[64] = {0};
for (size_t i = 0; i < dim->Size(); i++) {
auto dimValue = dim->operator[](i);
if (dim->operator[](i) < 0) {
dimValue = dim->operator[](i) + selfDimNum;
}
if (dimMask[dimValue] == 1) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Dim %ld appears multiple times in the list of dims",
dim->operator[](i));
return false;
} else {
dimMask[dimValue] = 1;
}
}
return true;
}
/**
 * Computes the output shape expected for a reduction of `selfShape` over `dim`.
 *
 * Axes that are reduced are dropped from the result, or kept with extent 1 when `keepDim` is
 * set; all other axes keep their extent. An empty `dim` marks every axis as reduced.
 *
 * @param selfShape   shape of the input tensor
 * @param dim         reduction axes; each entry is converted with GetPosDimWithStd
 *                    (defined outside this file section - presumably maps negative axes to
 *                    their non-negative index, TODO confirm)
 * @param keepDim     whether reduced axes stay in the result with extent 1
 * @param expectShape receives the computed dims via AppendDim; it is not cleared first
 */
static void ExpectShapeInferWithDimMask(
    const op::Shape& selfShape, const aclIntArray* dim, bool keepDim, op::Shape& expectShape)
{
    // Bit k set  <=>  axis k is reduced. Default construction leaves every bit cleared.
    bitset<MAX_MASK_LEN64> reducedAxes;

    const size_t dimCount = dim->Size();
    if (dimCount == 0) {
        // No explicit axes: reduce over all of them.
        reducedAxes.flip();
    } else {
        for (size_t k = 0; k < dimCount; ++k) {
            const int64_t axis = GetPosDimWithStd(dim->operator[](k), selfShape.GetDimNum());
            reducedAxes.set(axis);
        }
    }

    for (size_t axis = 0; axis < selfShape.GetDimNum(); ++axis) {
        if (reducedAxes[axis]) {
            if (keepDim) {
                expectShape.AppendDim(1);
            }
            continue;
        }
        expectShape.AppendDim(selfShape.GetDim(axis));
    }
}
// Validates tensor ranks and the shape of `out`.
//   - `self` and `out` are each checked against MAX_DIM_LEN via OP_CHECK_MAX_DIM.
//   - The shape expected for reducing self's view shape over `dim` (honouring `keepDim`) is
//     computed with ExpectShapeInferWithDimMask and compared with `out` through
//     OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE.
// Every macro is given `return false` as its failure action; the function returns true when
// none of them fires.
static bool CheckShape(const aclTensor* self, aclTensor* out, const aclIntArray* dim, bool keepDim) {
OP_CHECK_MAX_DIM(self, MAX_DIM_LEN, return false);
OP_CHECK_MAX_DIM(out, MAX_DIM_LEN, return false);
op::Shape expectShape;
ExpectShapeInferWithDimMask(self->GetViewShape(), dim, keepDim, expectShape);
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(out, expectShape, return false);
return true;
}
// Runs all parameter checks for the operator and maps the first failure to an aclnn status code.
// The order matters: the null check comes first because every later check dereferences the
// pointers, and the dim check precedes the shape check, which uses `dim` to infer the
// expected output shape.
// Returns ACLNN_ERR_PARAM_NULLPTR for a failed null check, ACLNN_ERR_PARAM_INVALID for any
// other failed check, and ACLNN_SUCCESS when everything passes.
static aclnnStatus CheckParams(const aclTensor* self, const aclIntArray* dim, aclTensor* out, bool keepDim) {
CHECK_RET(CheckNotNull(self, dim, out), ACLNN_ERR_PARAM_NULLPTR);
CHECK_RET(CheckDtypeValid(self, out), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckPromoteType(self, out), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckDimValid(self, dim), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckShape(self, out, dim, keepDim), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
/**
 * Schedules ops on `executor` that write the constant `val` into `out`.
 *
 * The fill shape is obtained from FillScalarGetShapeValue(out) and handed to l0op::Fill both as
 * an INT64 tensor and as an int array; `val` is converted to a one-element tensor of out's
 * dtype. The Fill result is then copied into `out` with ViewCopy.
 *
 * @param out      destination tensor
 * @param val      value to fill with
 * @param executor executor on which the ops are created
 * @return ACLNN_SUCCESS, or ACLNN_ERR_INNER_NULLPTR when Fill or ViewCopy returns nullptr
 */
static aclnnStatus FillScalar(aclTensor* out, float val, aclOpExecutor* executor)
{
    FVector<int64_t> fillShape = FillScalarGetShapeValue(out);
    auto shapeTensor = executor->ConvertToTensor(fillShape.data(), fillShape.size(), DataType::DT_INT64);
    auto shapeList = executor->AllocIntArray(fillShape.data(), fillShape.size());

    FVector<float> scalarHolder = {val};
    auto scalarTensor =
        executor->ConvertToTensor(scalarHolder.data(), scalarHolder.size(), out->GetDataType());

    auto filled = l0op::Fill(shapeTensor, scalarTensor, shapeList, executor);
    CHECK_RET(filled != nullptr, ACLNN_ERR_INNER_NULLPTR);

    auto copied = l0op::ViewCopy(filled, out, executor);
    CHECK_RET(copied != nullptr, ACLNN_ERR_INNER_NULLPTR);

    return ACLNN_SUCCESS;
}
/**
 * First phase of aclnnLogSumExp: validates the arguments, builds the op graph on a fresh
 * executor and reports the workspace size that the second phase needs.
 *
 * The graph subtracts the per-reduction maximum from the input before calling ReduceLogSumExp
 * and adds it back afterwards. Where the maximum is +/-inf it is replaced by 0 first, so the
 * subtraction does not produce inf - inf.
 *
 * @param self          input tensor
 * @param dim           axes to reduce over; an empty list means all axes of `self`
 * @param keepDim       whether reduced axes are kept in the result
 * @param out           output tensor; its dtype and shape are validated by CheckParams
 * @param workspaceSize receives the executor's workspace size (0 for an empty `self`)
 * @param executor      receives the executor that holds the built graph
 * @return ACLNN_SUCCESS on success; the CheckParams error code for invalid arguments;
 *         ACLNN_ERR_INNER_CREATE_EXECUTOR / ACLNN_ERR_INNER_NULLPTR for internal failures
 */
aclnnStatus aclnnLogSumExpGetWorkspaceSize(const aclTensor* self, const aclIntArray* dim, bool keepDim,
                                           aclTensor* out, uint64_t* workspaceSize, aclOpExecutor** executor) {
    OP_CHECK_COMM_INPUT(workspaceSize, executor);
    L2_DFX_PHASE_1(aclnnLogSumExp, DFX_IN(self, dim, keepDim), DFX_OUT(out));

    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);

    auto ret = CheckParams(self, dim, out, keepDim);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);

    // Empty input: the result is filled with -inf and no workspace is needed.
    if (self->IsEmpty()) {
        ret = FillScalar(out, -INFINITY, uniqueExecutor.get());
        CHECK_RET(ret == ACLNN_SUCCESS, ret);
        *workspaceSize = 0UL;
        uniqueExecutor.ReleaseTo(executor);
        return ret;
    }

    // A non-ND storage format is only warned about, not rejected.
    if (self->GetStorageFormat() != Format::FORMAT_ND) {
        OP_LOGW("Format only support ND");
    }

    auto selfContiguous = l0op::Contiguous(self, uniqueExecutor.get());
    CHECK_RET(selfContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Half / bfloat16 / integral inputs are computed in float32; other dtypes are kept as they are.
    auto promoteType = self->GetDataType();
    if (self->GetDataType() == op::DataType::DT_FLOAT16 || self->GetDataType() == op::DataType::DT_BF16 ||
        IsIntegralType(self->GetDataType(), true)) {
        promoteType = op::DataType::DT_FLOAT;
    }
    auto selfCasted = l0op::Cast(selfContiguous, promoteType, uniqueExecutor.get());
    CHECK_RET(selfCasted != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // An empty `dim` is expanded to the explicit list [0, rank). The list is staged in an FVector
    // instead of a variable-length array, which is not standard C++ and would have zero length
    // for a rank-0 input.
    // NOTE(review): as before, the staging buffer goes out of scope right after AllocIntArray,
    // so AllocIntArray is presumed to copy the values - confirm.
    const aclIntArray* reduceDims = dim;
    if (dim->Size() == 0) {
        const size_t rank = self->GetViewShape().GetDimNum();
        FVector<int64_t> allAxes;
        for (size_t i = 0; i < rank; i++) {
            allAxes.push_back(static_cast<int64_t>(i));
        }
        reduceDims = uniqueExecutor.get()->AllocIntArray(allAxes.data(), allAxes.size());
    }

    // Maximum over the reduced axes. The later SqueezeNd for !keepDim suggests that the reduced
    // axes are kept here - TODO confirm the meaning of the two `true` flags against ReduceMax.
    auto selfMax = l0op::ReduceMax(selfCasted, reduceDims, true, true, uniqueExecutor.get());
    CHECK_RET(selfMax != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // infMask marks the positions where |max| == inf.
    auto absSelfMax = l0op::Abs(selfMax, uniqueExecutor.get());
    CHECK_RET(absSelfMax != nullptr, ACLNN_ERR_INNER_NULLPTR);
    FVector<float> infVector = {INFINITY};
    auto infTensor = uniqueExecutor.get()->ConvertToTensor(infVector.data(), infVector.size(),
                                                           absSelfMax->GetDataType());
    CHECK_RET(infTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto infMask = l0op::Equal(absSelfMax, infTensor, uniqueExecutor.get());
    CHECK_RET(infMask != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Replace infinite maxima by 0 so that the subtraction below never computes inf - inf.
    FVector<float> zeroVector = {0};
    auto zeroTensor = uniqueExecutor.get()->ConvertToTensor(zeroVector.data(), zeroVector.size(),
                                                            selfMax->GetDataType());
    CHECK_RET(zeroTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto infZeroSelfMax = l0op::MaskedFill(selfMax, infMask, zeroTensor, uniqueExecutor.get());
    CHECK_RET(infZeroSelfMax != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // ReduceLogSumExp(self - max) ...
    auto selfSubMax = l0op::Sub(selfCasted, infZeroSelfMax, uniqueExecutor.get());
    CHECK_RET(selfSubMax != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto logSumExpOpOut = l0op::ReduceLogSumExp(selfSubMax, reduceDims, keepDim, uniqueExecutor.get());
    CHECK_RET(logSumExpOpOut != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // ... + max. Without keepDim the max is squeezed over the reduced axes first so that its
    // shape matches the reduced result.
    if (!keepDim && (infZeroSelfMax->GetViewShape().GetDimNum() > 0)) {
        infZeroSelfMax = l0op::SqueezeNd(infZeroSelfMax, reduceDims, uniqueExecutor.get());
        CHECK_RET(infZeroSelfMax != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    auto logSumExpAddOut = l0op::Add(logSumExpOpOut, infZeroSelfMax, uniqueExecutor.get());
    CHECK_RET(logSumExpAddOut != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Cast to the requested output dtype and copy into `out`.
    auto logSumExpAddOutCasted = l0op::Cast(logSumExpAddOut, out->GetDataType(), uniqueExecutor.get());
    CHECK_RET(logSumExpAddOutCasted != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto viewCopyResult = l0op::ViewCopy(logSumExpAddOutCasted, out, uniqueExecutor.get());
    CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);

    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}
// Second phase of aclnnLogSumExp: after the L2_DFX_PHASE_2 macro for this API, forwards the
// workspace, its size, the executor and the stream unchanged to CommonOpExecutorRun and returns
// its status.
aclnnStatus aclnnLogSumExp(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream) {
L2_DFX_PHASE_2(aclnnLogSumExp);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif