* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_kl_div.h"
#include "kl_div.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "conversion/fill/op_api/fill.h"
#include "conversion/unsqueeze/op_host/op_api/unsqueeze.h"
#include "conversion/broadcast_to/op_api/broadcast_to.h"
#include "math/real_div/op_api/realdiv.h"
#include "opdev/platform.h"
#include "op_api/aclnn_check.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
enum Reduction
{
None = 0,
Mean = 1,
Sum = 2,
BatchMean = 3,
End
};
static std::string REDUCTION_NONE = "none";
static std::string REDUCTION_MEAN = "mean";
static std::string REDUCTION_SUM = "sum";
static std::string REDUCTION_BATCHMEAN = "batchmean";
constexpr size_t MAX_DIM_LEN = 8;
constexpr size_t BATCH_INDEX = 0;
static const std::initializer_list<op::DataType> ASCEND910_DTYPE_DTYPE_SUPPORT_LIST = {op::DataType::DT_FLOAT,
op::DataType::DT_FLOAT16};
static const std::initializer_list<op::DataType> ASCEND910B_DTYPE_DTYPE_SUPPORT_LIST = {op::DataType::DT_FLOAT,
op::DataType::DT_FLOAT16,
op::DataType::DT_BF16};
static bool CheckNotNull(const aclTensor* self, const aclTensor* target, const aclTensor* out) {
OP_CHECK_NULL(self, return false);
OP_CHECK_NULL(target, return false);
OP_CHECK_NULL(out, return false);
return true;
}
static void CheckFormat(const aclTensor* self, const aclTensor* target){
ge::Format selfStorageFormat = self->GetStorageFormat();
ge::Format targetStorageFormat = target->GetStorageFormat();
if (selfStorageFormat != ge::Format::FORMAT_ND || targetStorageFormat != ge::Format::FORMAT_ND){
OP_LOGW("aclnnKlDiv only support format ND.");
}
}
static const std::initializer_list<DataType>& GetDtypeSupportList() {
if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910B ||
GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910_93 ||
IsRegBase()) {
return ASCEND910B_DTYPE_DTYPE_SUPPORT_LIST;
} else {
return ASCEND910_DTYPE_DTYPE_SUPPORT_LIST;
}
}
static bool CheckDtypeValid(const aclTensor* self, const aclTensor* target, const aclTensor* out) {
const auto& supportList = GetDtypeSupportList();
OP_CHECK_DTYPE_NOT_SUPPORT(self, supportList, return false);
OP_CHECK_DTYPE_NOT_SUPPORT(target, supportList, return false);
OP_CHECK_DTYPE_NOT_SUPPORT(out, supportList, return false);
return true;
}
static bool CheckPromoteType(const aclTensor* self, const aclTensor* target, const aclTensor* out) {
op::DataType promoteType = op::PromoteType(self->GetDataType(), target->GetDataType());
if (promoteType == op::DataType::DT_UNDEFINED) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Self dtype %s and target dtype %s can not promote dtype.",
op::ToString(self->GetDataType()).GetString(), op::ToString(target->GetDataType()).GetString());
return false;
}
OP_CHECK_RESULT_DTYPE_CAST_FAILED(promoteType, out->GetDataType(), return false);
return true;
}
static bool CheckShape(const aclTensor* self, const aclTensor* target) {
OP_CHECK_MAX_DIM(self, MAX_DIM_LEN, return false);
OP_CHECK_MAX_DIM(target, MAX_DIM_LEN, return false);
return true;
}
static bool CheckReduction(int64_t reduction) {
if (reduction < 0 || reduction > 3) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"Invalid reduction value: %ld. Allowed values are 0 (None), 1 (Mean), 2 (Sum), 3 (Batchmean).", reduction);
return false;
}
return true;
}
static aclnnStatus CheckParams(const aclTensor* self, const aclTensor* target, aclTensor* out, int64_t reduction) {
CHECK_RET(CheckNotNull(self, target, out), ACLNN_ERR_PARAM_NULLPTR);
CHECK_RET(CheckDtypeValid(self, target, out), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckPromoteType(self, target, out), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckShape(self, target), ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckReduction(reduction), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
static std::string &GetReductionStr(int64_t reduction) {
if (reduction == Mean) {
return REDUCTION_MEAN;
} else if (reduction == Sum) {
return REDUCTION_SUM;
} else if (reduction == BatchMean) {
return REDUCTION_BATCHMEAN;
}
return REDUCTION_NONE;
}
aclnnStatus FillScalar(aclTensor *out, float val, aclOpExecutor *executor) {
FVector<int64_t> shape;
shape.push_back(1);
auto dims = executor->ConvertToTensor(shape.data(), shape.size(), DataType::DT_INT64);
CHECK_RET(dims != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto shapeArray = executor->AllocIntArray(shape.data(), shape.size());
CHECK_RET(shapeArray != nullptr, ACLNN_ERR_INNER_NULLPTR);
FVector<float> valVector = {val};
auto valTensor = executor->ConvertToTensor(valVector.data(), valVector.size(), out->GetDataType());
CHECK_RET(valTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto fillOut = l0op::Fill(dims, valTensor, shapeArray, executor);
CHECK_RET(fillOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto viewCopyResult = l0op::ViewCopy(fillOut, out, executor);
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnKlDivGetWorkspaceSize(const aclTensor* self, const aclTensor* target, int64_t reduction,
bool logTarget, aclTensor* out, uint64_t* workspaceSize,
aclOpExecutor** executor) {
L2_DFX_PHASE_1(aclnnKlDiv, DFX_IN(self, target, reduction, logTarget), DFX_OUT(out));
auto uniqueExecutor = CREATE_EXECUTOR();
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
auto ret = CheckParams(self, target, out, reduction);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
CheckFormat(self, target);
if (self->IsEmpty() || target->IsEmpty()) {
if (reduction == Mean) {
ret = FillScalar(out, NAN, uniqueExecutor.get());
CHECK_RET(ret == ACLNN_SUCCESS, ret);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
} else if (reduction == Sum) {
ret = FillScalar(out, 0, uniqueExecutor.get());
CHECK_RET(ret == ACLNN_SUCCESS, ret);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
} else if(reduction == BatchMean) {
if (self->GetViewShape().GetDim(BATCH_INDEX) == 0) {
ret = FillScalar(out, NAN, uniqueExecutor.get());
CHECK_RET(ret == ACLNN_SUCCESS, ret);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
} else {
ret = FillScalar(out, 0, uniqueExecutor.get());
CHECK_RET(ret == ACLNN_SUCCESS, ret);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
}
} else {
*workspaceSize = 0;
}
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
auto promoteType = op::PromoteType(self->GetDataType(), target->GetDataType());
OP_CHECK_RESULT_DTYPE_CAST_FAILED(promoteType, out->GetDataType(), return false);
auto selfContiguous = l0op::Contiguous(self, uniqueExecutor.get());
CHECK_RET(selfContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto selfCasted = l0op::Cast(selfContiguous, promoteType, uniqueExecutor.get());
CHECK_RET(selfCasted != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto targetContiguous = l0op::Contiguous(target, uniqueExecutor.get());
CHECK_RET(targetContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto targetCasted = l0op::Cast(targetContiguous, promoteType, uniqueExecutor.get());
CHECK_RET(targetCasted != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto selfDimNum = self->GetViewShape().GetDimNum();
auto targetDimNum = target->GetViewShape().GetDimNum();
auto selfUnsqueeze = selfCasted;
auto targetUnsqueeze = targetCasted;
const int64_t dim = 0;
if (selfDimNum == 0) {
selfUnsqueeze = l0op::UnsqueezeNd(selfCasted, dim, uniqueExecutor.get());
CHECK_RET(selfUnsqueeze != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
if (targetDimNum == 0) {
targetUnsqueeze = l0op::UnsqueezeNd(targetCasted, dim, uniqueExecutor.get());
CHECK_RET(targetUnsqueeze != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
op::Shape broadcastShape;
auto selfBroadcast = selfUnsqueeze;
auto targetBroadcast = targetUnsqueeze;
OP_CHECK_BROADCAST_AND_INFER_SHAPE(selfUnsqueeze, targetUnsqueeze, broadcastShape,
return ACLNN_ERR_PARAM_INVALID);
if (reduction == None) {
if (selfDimNum == 0 && targetDimNum == 0) {
op::Shape outputShape;
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(out, outputShape, return ACLNN_ERR_PARAM_INVALID);
} else {
OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(out, broadcastShape, return ACLNN_ERR_PARAM_INVALID);
}
}
op::FVector<int64_t, op::MAX_DIM_NUM> broadcastDims = op::ToShapeVector(broadcastShape);
auto broadcastShapeArray = uniqueExecutor.get()->AllocIntArray(broadcastDims.data(), broadcastDims.size());
selfBroadcast = l0op::BroadcastTo(selfUnsqueeze, broadcastShapeArray, uniqueExecutor.get());
CHECK_RET(selfBroadcast != nullptr, ACLNN_ERR_INNER_NULLPTR);
targetBroadcast = l0op::BroadcastTo(targetUnsqueeze, broadcastShapeArray, uniqueExecutor.get());
CHECK_RET(targetBroadcast != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto klRes = l0op::KlDiv(selfBroadcast, targetBroadcast, GetReductionStr(reduction), logTarget, uniqueExecutor.get());
CHECK_RET(klRes != nullptr, ACLNN_ERR_INNER_NULLPTR);
CHECK_RET(CheckReduceOutShape(klRes, out), ACLNN_ERR_PARAM_INVALID);
auto klResCasted = l0op::Cast(klRes, out->GetDataType(), uniqueExecutor.get());
CHECK_RET(klResCasted != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto viewCopyResult = l0op::ViewCopy(klResCasted, out, uniqueExecutor.get());
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnKlDiv(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream) {
L2_DFX_PHASE_2(aclnnKlDiv);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif