* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <cstdint>
#include "not_equal_aicpu.h"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
// Op type name used when registering the kernel.
const char *const kNotEqual = "NotEqual";
// Tensor counts handed to NormalCheck(): two inputs, one output.
constexpr uint32_t kInputNum = 2U;
constexpr uint32_t kOutputNum = 1U;
// Output element count from which the broadcasting path is sharded via ParallelFor.
constexpr int64_t kParallelDataNum = static_cast<int64_t>(6) * 1024;
// Output element count from which the no-broadcast path is sharded via ParallelFor.
constexpr int64_t kParallelDataNumSameShape = static_cast<int64_t>(7) * 1024;
}  // namespace
namespace aicpu {
// Scalar "not equal" primitive: writes the negation of IsValueEqual<T>(a, b)
// into the bool pointed to by `output`.
template <typename T>
inline void NotEqualImpl(T a, T b, bool *output) {
  const bool is_same = IsValueEqual<T>(a, b);
  *output = !is_same;
}
// Computes out[i] = (x1[xi] != x2[yi]) for every output element, where xi / yi
// are the input offsets that `bcast` reports for output index i.
//
// ctx   : kernel context; Input(0) / Input(1) are read as T, Output(0) is written as bool.
// bcast : broadcast helper supplying GetBroadcastXIndex / GetBroadcastYIndex.
// Returns KERNEL_STATUS_OK on success; a ParallelFor failure is handed to
// KERNEL_HANDLE_ERROR (presumably logs the message and returns the error status).
template <typename T>
uint32_t NotEqualBcastCompute(const CpuKernelContext &ctx, const Bcast &bcast) {
  auto x1 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
  auto x2 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
  bool *out = reinterpret_cast<bool *>(ctx.Output(0)->GetData());
  int64_t data_num = ctx.Output(0)->NumElements();
  if (data_num >= kParallelDataNum) {
    // Raw pointers are captured by value (cheap, no extra indirection); the
    // broadcast helper is captured by reference.
    auto sharder_not_equal = [x1, x2, out, &bcast](int64_t start, int64_t end) {
      for (int64_t i = start; i < end; ++i) {
        NotEqualImpl(*(x1 + bcast.GetBroadcastXIndex(i)), *(x2 + bcast.GetBroadcastYIndex(i)), (out + i));
      }
    };
    // The message names this kernel (it previously said "Equal", a copy-paste slip).
    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, 1, sharder_not_equal),
                        "NotEqual Compute failed.")
  } else {
    // Below the threshold the whole range is processed on the calling thread.
    for (int64_t i = 0; i < data_num; ++i) {
      NotEqualImpl(*(x1 + bcast.GetBroadcastXIndex(i)), *(x2 + bcast.GetBroadcastYIndex(i)), (out + i));
    }
  }
  return KERNEL_STATUS_OK;
}
// Fills output[start, end) with element-wise "not equal" results for the three
// layouts that need no index mapping:
//   SAME_SHAPE    - both inputs are walked with the output index,
//   X_ONE_ELEMENT - element 0 of input 0 is compared against every element of input 1,
//   Y_ONE_ELEMENT - every element of input 0 is compared against element 0 of input 1.
// Any other `type` only produces a warning log and leaves the output untouched.
template <typename T>
void NotEqualSpecialCompute(BcastShapeType type, int64_t start, int64_t end, const CpuKernelContext &ctx) {
  T *lhs = reinterpret_cast<T *>(ctx.Input(0)->GetData());
  T *rhs = reinterpret_cast<T *>(ctx.Input(1)->GetData());
  bool *result = reinterpret_cast<bool *>(ctx.Output(0)->GetData());

  if (type == BcastShapeType::SAME_SHAPE) {
    for (int64_t pos = start; pos < end; ++pos) {
      NotEqualImpl(lhs[pos], rhs[pos], &result[pos]);
    }
    return;
  }
  if (type == BcastShapeType::X_ONE_ELEMENT) {
    for (int64_t pos = start; pos < end; ++pos) {
      NotEqualImpl(lhs[0], rhs[pos], &result[pos]);
    }
    return;
  }
  if (type == BcastShapeType::Y_ONE_ELEMENT) {
    for (int64_t pos = start; pos < end; ++pos) {
      NotEqualImpl(lhs[pos], rhs[0], &result[pos]);
    }
    return;
  }
  KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
}
// Handles the cases that need no broadcast index mapping: equal element counts,
// or one of the two inputs holding exactly one element.
//
// The layout is derived from the element counts: equal counts -> SAME_SHAPE,
// otherwise x1 having one element -> X_ONE_ELEMENT, else Y_ONE_ELEMENT.
// Returns KERNEL_STATUS_OK on success; a ParallelFor failure is handed to
// KERNEL_HANDLE_ERROR (presumably logs the message and returns the error status).
template <typename T>
uint32_t NotEqualNoBcastCompute(const CpuKernelContext &ctx) {
  int64_t element_num_x1 = ctx.Input(0)->NumElements();
  int64_t element_num_x2 = ctx.Input(1)->NumElements();
  int64_t data_num = static_cast<int64_t>(ctx.Output(0)->NumElements());
  BcastShapeType type = (element_num_x1 == element_num_x2 ? BcastShapeType::SAME_SHAPE :
                         (element_num_x1 == 1 ? BcastShapeType::X_ONE_ELEMENT : BcastShapeType::Y_ONE_ELEMENT));
  if (data_num >= kParallelDataNumSameShape) {
    // Each shard processes its own [start, end) slice of the output.
    auto sharder_not_equal = [&type, &ctx](int64_t start, int64_t end) {
      NotEqualSpecialCompute<T>(type, start, end, ctx);
    };
    // The message names this kernel (it previously said "Equal", a copy-paste slip).
    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, 1, sharder_not_equal),
                        "NotEqual Compute failed.")
  } else {
    // Below the threshold the whole range is processed on the calling thread.
    NotEqualSpecialCompute<T>(type, 0, data_num, ctx);
  }
  return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t NotEqualCompute(const CpuKernelContext &ctx) {
Tensor *tensorx1 = ctx.Input(0);
auto shapex1 = tensorx1->GetTensorShape()->GetDimSizes();
Tensor *tensorx2 = ctx.Input(1);
auto shapex2 = tensorx2->GetTensorShape()->GetDimSizes();
Tensor *output = ctx.Output(0);
KERNEL_LOG_INFO("CpuKernel[%s], input x1 : size[%lu], input x2: size[%lu], output: size[%lu]",
ctx.GetOpType().c_str(), tensorx1->GetDataSize(),
tensorx2->GetDataSize(), output->GetDataSize());
bool no_need_bcast = (shapex1 == shapex2) || (tensorx1->NumElements() == 1) ||
(tensorx2->NumElements() == 1);
if (no_need_bcast) {
return NotEqualNoBcastCompute<T>(ctx);
}
Bcast bcast(shapex1, shapex2);
if (!bcast.IsValid()) {
KERNEL_LOG_ERROR("[%s] broadcast failed.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return NotEqualBcastCompute<T>(ctx, bcast);
}
// Runs NotEqualCompute<T> and logs an error if it did not return KERNEL_STATUS_OK.
// The status produced by NotEqualCompute<T> is returned to the caller unchanged.
template <typename T>
uint32_t NotEqualComputeCase(const CpuKernelContext &ctx) {
  uint32_t result = NotEqualCompute<T>(ctx);
  if (result != KERNEL_STATUS_OK) {
    // `result` is unsigned, so it is printed with %u (the original used %d).
    KERNEL_LOG_ERROR("NotEqual kernel compute failed, result = [%u].", result);
    return result;
  }
  return KERNEL_STATUS_OK;
}
// Kernel entry point for NotEqual.
//
// Validates the context with NormalCheck (kInputNum inputs, kOutputNum outputs),
// requires both inputs to share one data type, then dispatches to the typed
// implementation. Data types outside the switch below are rejected with
// KERNEL_STATUS_PARAM_INVALID.
uint32_t NotEqualCpuKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum),
                      "Check NotEqual params failed.");
  DataType x1_type = ctx.Input(0)->GetDataType();
  DataType x2_type = ctx.Input(1)->GetDataType();
  KERNEL_CHECK_FALSE((x1_type == x2_type), KERNEL_STATUS_PARAM_INVALID,
                     "DataType of x1 [%d] should be same as x2 [%d].",
                     static_cast<int32_t>(x1_type), static_cast<int32_t>(x2_type))
  switch (x1_type) {
    case DT_INT8:
      return NotEqualComputeCase<int8_t>(ctx);
    case DT_INT16:
      return NotEqualComputeCase<int16_t>(ctx);
    case DT_INT32:
      return NotEqualComputeCase<int32_t>(ctx);
    case DT_INT64:
      return NotEqualComputeCase<int64_t>(ctx);
    case DT_UINT8:
      return NotEqualComputeCase<uint8_t>(ctx);
    case DT_FLOAT16:
      return NotEqualComputeCase<Eigen::half>(ctx);
    case DT_FLOAT:
      return NotEqualComputeCase<float>(ctx);
    case DT_DOUBLE:
      return NotEqualComputeCase<double>(ctx);
    case DT_BOOL:
      return NotEqualComputeCase<bool>(ctx);
    default:
      break;
  }
  // Only unsupported data types reach this point. The enum is cast explicitly so
  // the variadic logger receives the type that %u expects.
  KERNEL_LOG_WARN("NotEqual kernel data type [%u] not support.", static_cast<uint32_t>(x1_type));
  return KERNEL_STATUS_PARAM_INVALID;
}
// Registers NotEqualCpuKernel under the op type name held in kNotEqual ("NotEqual").
// NOTE(review): presumably this is what lets the AICPU framework look the kernel up by op type - confirm against the macro definition.
REGISTER_CPU_KERNEL(kNotEqual, NotEqualCpuKernel);
}