* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "cdist.h"
#include "opdev/make_op_executor.h"
#include "opdev/op_def.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
#include "opdev/platform.h"
#include "op_api/aclnn_check.h"
using namespace op;
namespace l0op {
// Offset from the end of x2's dimensions used by InferShapeForA5 to pick x2's row
// dimension (index x2DimNum - MIN_DIM_LEN).
constexpr size_t MIN_DIM_LEN{2};

// Data types accepted for both inputs on the ASCEND910B / ASCEND910_93 AI Core path.
static const std::initializer_list<DataType> ASCEND910B_AICORE_DTYPE_SUPPORT_LIST{
    DataType::DT_FLOAT, DataType::DT_FLOAT16};

// Data types accepted for both inputs on the RegBase (IsRegBase()) AI Core path; adds BF16.
static const std::initializer_list<DataType> ASCEND950_AICORE_DTYPE_SUPPORT_LIST{
    DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16};

OP_TYPE_REGISTER(Cdist);
/**
 * @brief Output shape for the ASCEND910B / ASCEND910_93 path: x1's view shape with its
 *        last dimension dropped.
 *
 * NOTE(review): this relies on x1's leading dimensions already describing the full output
 * shape (e.g. inputs broadcast by the aclnn caller) — confirm against the caller.
 *
 * @param x1 First Cdist input.
 * @return Shape with x1's dimensions [0, dimNum - 1); empty when x1 has rank 0.
 */
const Shape InferShapeForA2A3(const aclTensor* x1)
{
    const Shape& x1Shape = x1->GetViewShape();
    const size_t dimNum = x1Shape.GetDimNum();
    Shape outShape;
    // `i + 1 < dimNum` rather than `i < dimNum - 1`: the latter underflows size_t for rank 0
    // and would index far past the end of the shape.
    for (size_t i = 0; i + 1 < dimNum; ++i) {
        outShape.AppendDim(x1Shape.GetDim(i));
    }
    return outShape;
}
/**
 * @brief Output shape for the RegBase path: x1's view shape without its last dimension,
 *        followed by x2's second-to-last dimension.
 *
 * @param x1 First Cdist input.
 * @param x2 Second Cdist input; presumably rank >= MIN_DIM_LEN (validated by the caller — TODO confirm).
 * @return The inferred output shape. If x2 has fewer than MIN_DIM_LEN dimensions, its
 *         dimension is not appended instead of being read out of range.
 */
const Shape InferShapeForA5(const aclTensor* x1, const aclTensor* x2)
{
    const op::Shape& x1Shape = x1->GetViewShape();
    const op::Shape& x2Shape = x2->GetViewShape();
    // Keep all dimension counts unsigned: mixing int64_t with the size_t MIN_DIM_LEN made
    // `x2DimNum - MIN_DIM_LEN` unsigned and let it wrap for low-rank x2.
    const size_t x1DimNum = x1Shape.GetDimNum();
    const size_t x2DimNum = x2Shape.GetDimNum();

    op::Shape outShape;
    for (size_t i = 0; i + 1 < x1DimNum; ++i) {
        outShape.AppendDim(x1Shape.GetDim(i));
    }
    if (x2DimNum >= MIN_DIM_LEN) {
        outShape.AppendDim(x2Shape.GetDim(x2DimNum - MIN_DIM_LEN));
    }
    return outShape;
}
/**
 * @brief Add the Cdist AI Core kernel (attribute p only) to the executor's launcher list.
 *
 * @param x1 First input tensor.
 * @param x2 Second input tensor.
 * @param p Value passed as the kernel attribute `p`.
 * @param output Pre-allocated output tensor.
 * @param executor Executor the kernel is added to (referenced implicitly by the launcher macro).
 * @return `output` on success; nullptr if ADD_TO_LAUNCHER_LIST_AICORE reports a failure.
 */
static inline const aclTensor* CdistAiCore(
    const aclTensor* x1, const aclTensor* x2, float p, aclTensor* output, aclOpExecutor* executor)
{
    L0_DFX(CdistAiCore, x1, x2, p, output);
    const auto launchStatus =
        ADD_TO_LAUNCHER_LIST_AICORE(Cdist, OP_INPUT(x1, x2), OP_OUTPUT(output), OP_ATTR(p));
    if (launchStatus != ACLNN_SUCCESS) {
        OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "CdistAiCore ADD_TO_LAUNCHER_LIST_AICORE failed.");
        return nullptr;
    }
    return output;
}
/**
 * @brief Add the Cdist AI Core kernel (attributes p and compute_mode) to the executor's
 *        launcher list.
 *
 * @param x1 First input tensor.
 * @param x2 Second input tensor.
 * @param p Value passed as the kernel attribute `p`.
 * @param compute_mode Value passed as the kernel attribute `compute_mode`.
 * @param output Pre-allocated output tensor.
 * @param executor Executor the kernel is added to (referenced implicitly by the launcher macro).
 * @return `output` on success; nullptr if ADD_TO_LAUNCHER_LIST_AICORE reports a failure.
 */
static inline const aclTensor* CdistAiCore(
    const aclTensor* x1, const aclTensor* x2, float p, int64_t compute_mode, aclTensor* output, aclOpExecutor* executor)
{
    L0_DFX(CdistAiCore, x1, x2, p, compute_mode, output);
    const auto launchStatus =
        ADD_TO_LAUNCHER_LIST_AICORE(Cdist, OP_INPUT(x1, x2), OP_ATTR(p, compute_mode), OP_OUTPUT(output));
    if (launchStatus != ACLNN_SUCCESS) {
        OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "CdistAiCore ADD_TO_LAUNCHER_LIST_AICORE failed.");
        return nullptr;
    }
    return output;
}
/**
 * @brief Build the l0 Cdist operator for the current SoC.
 *
 * Two AI Core paths are supported:
 *  - ASCEND910B / ASCEND910_93: inputs must be in ASCEND910B_AICORE_DTYPE_SUPPORT_LIST, the
 *    output shape comes from InferShapeForA2A3, and the kernel gets attribute p only.
 *  - RegBase platforms (IsRegBase()): inputs must be in ASCEND950_AICORE_DTYPE_SUPPORT_LIST,
 *    the output shape comes from InferShapeForA5 and is then refined by INFER_SHAPE, and the
 *    kernel gets attributes p and compute_mode.
 *
 * @param x1 First input tensor; its dtype is also used for the output tensor.
 * @param x2 Second input tensor.
 * @param p Forwarded as kernel attribute `p` (presumably the p-norm order).
 * @param compute_mode Forwarded as a kernel attribute on RegBase platforms only.
 * @param executor Executor used to allocate the output and register the kernel launch.
 * @return The output tensor, or nullptr if the SoC or an input dtype is unsupported, or if
 *         allocation, shape inference or kernel registration fails.
 */
const aclTensor* Cdist(const aclTensor *x1, const aclTensor *x2, float p, int64_t compute_mode,
    aclOpExecutor *executor)
{
    const SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
    const bool isA2A3 = socVersion == SocVersion::ASCEND910B || socVersion == SocVersion::ASCEND910_93;
    // IsRegBase() is only consulted when the SoC is not 910B/910_93.
    const bool isRegBase = !isA2A3 && IsRegBase();
    if (!isA2A3 && !isRegBase) {
        OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "Cdist is not supported on the current SoC version.");
        return nullptr;
    }

    // Validate dtypes before allocating, so an unsupported dtype neither wastes an allocation
    // nor fails without a log message.
    const std::initializer_list<DataType>& dtypeSupportList =
        isA2A3 ? ASCEND910B_AICORE_DTYPE_SUPPORT_LIST : ASCEND950_AICORE_DTYPE_SUPPORT_LIST;
    if (!CheckType(x1->GetDataType(), dtypeSupportList) || !CheckType(x2->GetDataType(), dtypeSupportList)) {
        OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "Data type not supported.");
        return nullptr;
    }

    const Shape outShape = isA2A3 ? InferShapeForA2A3(x1) : InferShapeForA5(x1, x2);
    auto out = executor->AllocTensor(outShape, x1->GetDataType(), op::Format::FORMAT_ND);
    CHECK_RET(out != nullptr, nullptr);

    if (isA2A3) {
        return CdistAiCore(x1, x2, p, out, executor);
    }

    // RegBase: refine the output shape with INFER_SHAPE and stop instead of launching on failure.
    const auto inferShapeRet = INFER_SHAPE(Cdist, OP_INPUT(x1, x2), OP_OUTPUT(out), OP_ATTR(p, compute_mode));
    if (inferShapeRet != ACLNN_SUCCESS) {
        OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "Cdist INFER_SHAPE failed.");
        return nullptr;
    }
    return CdistAiCore(x1, x2, p, compute_mode, out, executor);
}
}