* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_cdist.h"
#include "cdist.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "aclnn_kernels/reshape.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/platform.h"
#include "conversion/fill/op_api/fill.h"
#include "conversion/unsqueeze/op_host/op_api/unsqueeze.h"
#include "conversion/broadcast_to/op_api/broadcast_to.h"
#include "op_api/aclnn_check.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
constexpr size_t MAX_DIM_LEN = 8;
constexpr size_t MIN_DIM_LEN = 2;
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
static const inline std::initializer_list<DataType>& GetSupportDtypeList(SocVersion socVersion)
{
static const std::initializer_list<DataType> emptyDtypes = {};
if (socVersion == SocVersion::ASCEND910B || socVersion == SocVersion::ASCEND910_93 || IsRegBase()) {
return DTYPE_SUPPORT_LIST;
} else {
return emptyDtypes;
}
}
static bool CheckNotNull(const aclTensor* x1, const aclTensor* x2, const aclTensor* out)
{
OP_CHECK_NULL(x1, return false);
OP_CHECK_NULL(x2, return false);
OP_CHECK_NULL(out, return false);
return true;
}
static bool CheckDtypeValid(const aclTensor* x1, const aclTensor* x2, const aclTensor* out, SocVersion socVersion)
{
static const std::initializer_list<DataType> dtypeSupportList = GetSupportDtypeList(socVersion);
OP_CHECK_DTYPE_NOT_SUPPORT(x1, dtypeSupportList, return false);
OP_CHECK_DTYPE_NOT_SUPPORT(x2, dtypeSupportList, return false);
OP_CHECK_DTYPE_NOT_SUPPORT(out, dtypeSupportList, return false);
OP_CHECK_DTYPE_NOT_SAME(x1, out, return false);
OP_CHECK_DTYPE_NOT_SAME(x2, out, return false);
return true;
}
static bool CheckShape(const aclTensor* x1, const aclTensor* x2, const aclTensor* out)
{
OP_CHECK_MAX_DIM(x1, MAX_DIM_LEN, return false);
OP_CHECK_MAX_DIM(x2, MAX_DIM_LEN, return false);
OP_CHECK_MAX_DIM(out, MAX_DIM_LEN, return false);
OP_CHECK_MIN_DIM(x1, MIN_DIM_LEN, return false);
OP_CHECK_MIN_DIM(x2, MIN_DIM_LEN, return false);
OP_CHECK_MIN_DIM(out, MIN_DIM_LEN, return false);
auto x1DimNum = x1->GetViewShape().GetDimNum();
auto x2DimNum = x2->GetViewShape().GetDimNum();
int64_t x1LastDim = static_cast<int64_t>(x1->GetViewShape().GetDim(x1DimNum - 1));
int64_t x2LastDim = static_cast<int64_t>(x2->GetViewShape().GetDim(x2DimNum - 1));
if(x1LastDim != x2LastDim) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The dim num of X1 and X2 must be the same. X1 got: %ld, X2 got: %ld.", x1LastDim, x2LastDim);
return false;
}
return true;
}
static bool CheckParamsLogic(float& p, const int64_t compute_mode, SocVersion socVersion)
{
if(p < 0 || std::isnan(p)) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "cdist only supports non-negative p values.");
return false;
}
if (std::fabs(p - 2.0) < 1e-6 && compute_mode != 2){
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "compute_mode %lld is not supported, only support 2 when pValue equals 2.0",
compute_mode);
return false;
}
if ((socVersion == SocVersion::ASCEND910B || socVersion == SocVersion::ASCEND910_93) && std::isinf(p)){
p = -1.0;
}
return true;
}
static void CheckFormat(const aclTensor* x1, const aclTensor* x2)
{
ge::Format x1StorageFormat = x1->GetStorageFormat();
ge::Format x2StorageFormat = x2->GetStorageFormat();
if (x1StorageFormat != ge::Format::FORMAT_ND || x2StorageFormat != ge::Format::FORMAT_ND) {
OP_LOGW("aclnnCdist only support ND format.");
}
}
static aclnnStatus CheckParams(
const aclTensor* x1, const aclTensor* x2, float& p, const int64_t compute_mode, aclTensor* out,
SocVersion socVersion)
{
CHECK_COND(CheckNotNull(x1, x2, out), ACLNN_ERR_PARAM_NULLPTR, "CheckNotNull failed!");
CHECK_COND(CheckDtypeValid(x1, x2, out, socVersion), ACLNN_ERR_PARAM_INVALID, "CheckDtypeValid failed!");
CHECK_COND(CheckShape(x1, x2, out), ACLNN_ERR_PARAM_INVALID, "CheckShape failed!");
CHECK_COND(CheckParamsLogic(p, compute_mode, socVersion), ACLNN_ERR_PARAM_INVALID, "CheckParamsLogic failed!");
CheckFormat(x1, x2);
return ACLNN_SUCCESS;
}
static bool CheckBatchDimNum(const aclTensor* x) {
auto xDimNum = x->GetViewShape().GetDimNum();
for(int64_t i = 0; i < static_cast<int64_t>(xDimNum - 1); i++) {
if(x->GetViewShape().GetDim(i) == 0) {
return true;
}
}
return false;
}
aclnnStatus aclnnCdistGetWorkspaceSize(
const aclTensor* x1, const aclTensor* x2, float p, int64_t compute_mode, aclTensor* out, uint64_t* workspaceSize,
aclOpExecutor** executor)
{
OP_CHECK_COMM_INPUT(workspaceSize, executor);
L2_DFX_PHASE_1(aclnnCdist, DFX_IN(x1, x2, p, compute_mode), DFX_OUT(out));
SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
auto uniqueExecutor = CREATE_EXECUTOR();
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
auto ret = CheckParams(x1, x2, p, compute_mode, out, socVersion);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
const aclTensor* CdistOutRet = nullptr;
auto x1DimNum = x1->GetViewShape().GetDimNum();
auto x2DimNum = x2->GetViewShape().GetDimNum();
if(CheckBatchDimNum(x1) || CheckBatchDimNum(x2)) {
*workspaceSize = 0;
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
bool needCast = x1->GetDataType() == op::DataType::DT_BF16 && (socVersion == SocVersion::ASCEND910B || socVersion == SocVersion::ASCEND910_93);
auto x1Contiguous = l0op::Contiguous(x1, uniqueExecutor.get());
CHECK_RET(x1Contiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto x2Contiguous = l0op::Contiguous(x2, uniqueExecutor.get());
CHECK_RET(x2Contiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
op::Shape x1BroadcastShape;
op::Shape x2BroadcastShape;
int64_t x1Rank = x1Contiguous->GetViewShape().GetDimNum();
int64_t x2Rank = x2Contiguous->GetViewShape().GetDimNum();
int64_t minRank = std::min(x1Rank, x2Rank);
int64_t maxRank = std::max(x1Rank, x2Rank);
for (int64_t i = 0; i < static_cast<int64_t>(maxRank - MIN_DIM_LEN); i++) {
int64_t dim1 = (x1Rank < maxRank) ?
((i < maxRank - minRank) ? 1 : x1Contiguous->GetViewShape().GetDim(i - maxRank + minRank)) :
x1Contiguous->GetViewShape().GetDim(i);
int64_t dim2 = (x2Rank < maxRank) ?
((i < maxRank - minRank) ? 1 : x2Contiguous->GetViewShape().GetDim(i - maxRank + minRank)) :
x2Contiguous->GetViewShape().GetDim(i);
if(dim1 != dim2) {
if(dim1 != 1 && dim2 != 1) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "X1 shape: %s and X2 shape: %s cannot broadcast!",
op::ToString(x1->GetViewShape()).GetString(), op::ToString(x2->GetViewShape()).GetString());
return ACLNN_ERR_PARAM_INVALID;
}
}
int64_t broadcastDim = std::max(dim1, dim2);
x1BroadcastShape.AppendDim(broadcastDim);
x2BroadcastShape.AppendDim(broadcastDim);
}
for (int64_t i = 1; i >= 0; i--) {
x1BroadcastShape.AppendDim(x1Contiguous->GetViewShape().GetDim(x1DimNum - i - 1));
x2BroadcastShape.AppendDim(x2Contiguous->GetViewShape().GetDim(x2DimNum - i - 1));
}
if(x1Contiguous->GetViewShape().GetDim(x1DimNum - 1) == 0) {
aclScalar* scalar = uniqueExecutor.get()->AllocScalar(0);
auto valueTensor = uniqueExecutor.get()->ConvertToTensor(scalar, out->GetDataType());
op::Shape outShape;
for(int64_t i = 0; i < static_cast<int64_t>(maxRank - 1); i++) {
int64_t dim = x1BroadcastShape.GetDim(i);
outShape.AppendDim(dim);
}
outShape.AppendDim(x2BroadcastShape.GetDim(maxRank - MIN_DIM_LEN));
auto outputDims = op::ToShapeVector(outShape);
aclIntArray* dimArray = uniqueExecutor.get()->AllocIntArray(outputDims.data(), outputDims.size());
CHECK_RET(dimArray != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto dimTensor = uniqueExecutor.get()->ConvertToTensor(dimArray, op::DataType::DT_INT64);
CdistOutRet = l0op::Fill(dimTensor, valueTensor, dimArray, uniqueExecutor.get());
} else {
op::FVector<int64_t, op::MAX_DIM_NUM> x1BroadcastDims = op::ToShapeVector(x1BroadcastShape);
auto x1BroadcastShapeArray = uniqueExecutor.get()->AllocIntArray(x1BroadcastDims.data(), x1BroadcastDims.size());
op::FVector<int64_t, op::MAX_DIM_NUM> x2BroadcastDims = op::ToShapeVector(x2BroadcastShape);
auto x2BroadcastShapeArray = uniqueExecutor.get()->AllocIntArray(x2BroadcastDims.data(), x2BroadcastDims.size());
auto x1Broadcast = l0op::BroadcastTo(x1Contiguous, x1BroadcastShapeArray, uniqueExecutor.get());
CHECK_RET(x1Broadcast != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto x2Broadcast = l0op::BroadcastTo(x2Contiguous, x2BroadcastShapeArray, uniqueExecutor.get());
CHECK_RET(x2Broadcast != nullptr, ACLNN_ERR_INNER_NULLPTR);
if (socVersion == SocVersion::ASCEND910B || socVersion == SocVersion::ASCEND910_93) {
if (needCast) {
x1Broadcast = l0op::Cast(x1Broadcast, op::DataType::DT_FLOAT, uniqueExecutor.get());
CHECK_RET(x1Broadcast != nullptr, ACLNN_ERR_INNER_NULLPTR);
x2Broadcast = l0op::Cast(x2Broadcast, op::DataType::DT_FLOAT, uniqueExecutor.get());
CHECK_RET(x2Broadcast != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
gert::Shape x1Shape = x1Broadcast->GetViewShape();
gert::Shape x2Shape = x2Broadcast->GetViewShape();
size_t dimNum = x1Shape.GetDimNum();
int64_t tensorR = x2Shape.GetDim(dimNum - 2);
int64_t tensorM = x1Shape.GetDim(dimNum - 1);
gert::Shape x1UnsqueezedShape = x1Shape;
gert::Shape x2UnsqueezedShape = x2Shape;
x1UnsqueezedShape.AppendDim(tensorM);
x2UnsqueezedShape.AppendDim(tensorM);
x1UnsqueezedShape.SetDim(dimNum - 1, 1);
x2UnsqueezedShape.SetDim(dimNum - 2, 1);
x2UnsqueezedShape.SetDim(dimNum - 1, tensorR);
auto x1Reshape = l0op::Reshape(x1Broadcast, x1UnsqueezedShape, uniqueExecutor.get());
auto x2Reshape = l0op::Reshape(x2Broadcast, x2UnsqueezedShape, uniqueExecutor.get());
gert::Shape finalInputShape = x1UnsqueezedShape;
finalInputShape.SetDim(dimNum - 1, x2UnsqueezedShape[dimNum - 1]);
FVector<int64_t, MAX_DIM_NUM> broadcastDims = ToShapeVector(finalInputShape);
aclIntArray* inputBroadcastShapeArray =
uniqueExecutor.get()->AllocIntArray(broadcastDims.data(), broadcastDims.size());
CHECK_RET(inputBroadcastShapeArray != nullptr, ACLNN_ERR_INNER_NULLPTR);
x1Broadcast = l0op::BroadcastTo(x1Reshape, inputBroadcastShapeArray, uniqueExecutor.get());
x2Broadcast = l0op::BroadcastTo(x2Reshape, inputBroadcastShapeArray, uniqueExecutor.get());
}
CdistOutRet = l0op::Cdist(x1Broadcast, x2Broadcast, p, compute_mode, uniqueExecutor.get());
if (needCast) {
CdistOutRet = l0op::Cast(CdistOutRet, op::DataType::DT_BF16, uniqueExecutor.get());
CHECK_RET(CdistOutRet != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
}
CHECK_RET(CdistOutRet != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto viewCopyResult = l0op::ViewCopy(CdistOutRet, out, uniqueExecutor.get());
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
aclnnStatus aclnnCdist(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
L2_DFX_PHASE_2(aclnnCdist);
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif