/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

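/*!
 * \file equal_aicpu.cpp
 * \brief AICPU kernel for the element-wise Equal operator: writes (x1 == x2) into a bool output,
 *        handling same-shape, single-element, and general broadcast inputs.
 */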
#include "equal_aicpu.h"

#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"

namespace {
// Number of outputs / inputs the Equal kernel expects (validated through NormalCheck in Compute).
constexpr uint32_t kOutputNum = 1;
constexpr uint32_t kInputNum = 2;
// Op type name the kernel is registered under.
const char *const kEqual = "Equal";
// Output element count at or above which the broadcast path shards work with CpuKernelUtils::ParallelFor.
constexpr int64_t kParallelDataNum = 6 * 1024;
// Output element count at or above which the same-shape / single-element path shards work.
constexpr int64_t kParallelDataNumSameShape = 7 * 1024;

// Expands to one `case` of the dtype switch in EqualCpuKernel::Compute: runs
// EqualCompute<TYPE>(CTX) and returns early with its status if it is not OK.
// `result` is uint32_t, so it is logged with %u (not %d).
#define EQUAL_COMPUTE_CASE(DTYPE, TYPE, CTX)                                          \
    case (DTYPE): {                                                                   \
        uint32_t result = EqualCompute<TYPE>(CTX);                                    \
        if (result != KERNEL_STATUS_OK) {                                             \
            KERNEL_LOG_ERROR("Equal kernel compute failed , result = [%u].", result); \
            return result;                                                            \
        }                                                                             \
        break;                                                                        \
    }
}  // namespace

namespace aicpu {
template <typename T>
inline void EqualImpl(T x1, T x2, bool *output) {
    // Delegate the element comparison to the project helper IsValueEqual<T>
    // (defined elsewhere) and store the outcome in the caller-provided slot.
    const bool is_equal = IsValueEqual<T>(x1, x2);
    *output = is_equal;
}

/**
 * Element-wise Equal for inputs that need general broadcasting.
 *
 * Output element i is computed from x1[bcast.GetBroadcastXIndex(i)] and
 * x2[bcast.GetBroadcastYIndex(i)]. When the output holds at least
 * kParallelDataNum elements the work is sharded with CpuKernelUtils::ParallelFor;
 * otherwise it runs serially on the calling thread.
 *
 * @param ctx   kernel context providing inputs 0/1 (type T) and output 0 (bool).
 * @param bcast broadcast index mapping built from the two input shapes.
 * @return KERNEL_STATUS_OK, or the ParallelFor failure status (via KERNEL_HANDLE_ERROR).
 */
template <typename T>
uint32_t BcastCompute(const CpuKernelContext &ctx, const Bcast &bcast) {
    T *x1_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
    T *x2_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
    bool *out_data = reinterpret_cast<bool *>(ctx.Output(0)->GetData());
    const int64_t total = ctx.Output(0)->NumElements();

    // Compares output positions [first, last); shared by the serial and sharded paths.
    auto compare_range = [x1_data, x2_data, out_data, &bcast](int64_t first, int64_t last) {
        for (int64_t idx = first; idx < last; ++idx) {
            EqualImpl(x1_data[bcast.GetBroadcastXIndex(idx)], x2_data[bcast.GetBroadcastYIndex(idx)],
                      out_data + idx);
        }
    };

    if (total < kParallelDataNum) {
        compare_range(0, total);
        return KERNEL_STATUS_OK;
    }
    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, total, 1, compare_range), "Equal Compute failed.");
    return KERNEL_STATUS_OK;
}

/**
 * Computes Equal for output positions [start, end) when no general broadcast is needed.
 *
 * - SAME_SHAPE:    out[i] = (x1[i] == x2[i])
 * - X_ONE_ELEMENT: out[i] = (x1[0] == x2[i])
 * - Y_ONE_ELEMENT: out[i] = (x1[i] == x2[0])
 * Any other shape type only logs a warning and writes nothing.
 */
template <typename T>
void SpecialCompute(BcastShapeType type, int64_t start, int64_t end, const CpuKernelContext &ctx) {
    T *lhs = reinterpret_cast<T *>(ctx.Input(0)->GetData());
    T *rhs = reinterpret_cast<T *>(ctx.Input(1)->GetData());
    bool *dst = reinterpret_cast<bool *>(ctx.Output(0)->GetData());

    if (type == BcastShapeType::SAME_SHAPE) {
        for (int64_t pos = start; pos < end; ++pos) {
            EqualImpl(lhs[pos], rhs[pos], dst + pos);
        }
    } else if (type == BcastShapeType::X_ONE_ELEMENT) {
        // x1 is a single element compared against every x2 element.
        for (int64_t pos = start; pos < end; ++pos) {
            EqualImpl(lhs[0], rhs[pos], dst + pos);
        }
    } else if (type == BcastShapeType::Y_ONE_ELEMENT) {
        // x2 is a single element compared against every x1 element.
        for (int64_t pos = start; pos < end; ++pos) {
            EqualImpl(lhs[pos], rhs[0], dst + pos);
        }
    } else {
        KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
    }
}

template <typename T>
uint32_t NoBcastCompute(const CpuKernelContext &ctx) {
   int64_t element_num_x1 = ctx.Input(0)->NumElements();
   int64_t element_num_x2 = ctx.Input(1)->NumElements();
   int64_t data_num = static_cast<int64_t>(ctx.Output(0)->NumElements());
   BcastShapeType type = (element_num_x1 == element_num_x2 ? BcastShapeType::SAME_SHAPE :
                                                             (element_num_x1 == 1 ? BcastShapeType::X_ONE_ELEMENT : BcastShapeType::Y_ONE_ELEMENT));
   if (data_num >= kParallelDataNumSameShape) {
       auto sharder_equal = [&type, &ctx](int64_t start, int64_t end) {
           SpecialCompute<T>(type, start, end, ctx);
       };
       KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, 1, sharder_equal),
                           "Equal Compute failed.")
   } else {
       SpecialCompute<T>(type, 0, data_num, ctx);
   }
   return KERNEL_STATUS_OK;
}

template <typename T>
uint32_t EqualCompute(const CpuKernelContext &ctx) {
   Tensor *tensorx1 = ctx.Input(0);
   auto shapex1 = tensorx1->GetTensorShape()->GetDimSizes();
   Tensor *tensorx2 = ctx.Input(1);
   auto shapex2 = tensorx2->GetTensorShape()->GetDimSizes();
   Tensor *output = ctx.Output(0);
   KERNEL_LOG_INFO("CpuKernel[%s], input x1 : size[%lu], input x2: size[%lu], output: size[%lu]",
                   ctx.GetOpType().c_str(), tensorx1->GetDataSize(),
                   tensorx2->GetDataSize(), output->GetDataSize());

   bool no_need_bcast = (shapex1 == shapex2) || (tensorx1->NumElements() == 1) ||
                        (tensorx2->NumElements() == 1);
   if (no_need_bcast) {
       return NoBcastCompute<T>(ctx);
   }

   Bcast bcast(shapex1, shapex2);
   if (!bcast.IsValid()) {
       KERNEL_LOG_ERROR("[%s] broadcast failed.", ctx.GetOpType().c_str());
       return KERNEL_STATUS_PARAM_INVALID;
   }
   return BcastCompute<T>(ctx, bcast);
}

/**
 * Kernel entry point: validates the input/output counts, requires both inputs to
 * share one dtype, then dispatches to EqualCompute<T> for that dtype.
 *
 * @return KERNEL_STATUS_OK on success; KERNEL_STATUS_PARAM_INVALID for mismatched
 *         or unsupported dtypes; otherwise the failing status from the compute path.
 */
uint32_t EqualCpuKernel::Compute(CpuKernelContext &ctx) {
    // check params
    KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check Equal params failed.");

    DataType x1_type = ctx.Input(0)->GetDataType();
    DataType x2_type = ctx.Input(1)->GetDataType();
    KERNEL_CHECK_FALSE((x1_type == x2_type), KERNEL_STATUS_PARAM_INVALID,
                       "DataType of x1 [%d] should be same as x2 [%d].",
                       static_cast<int32_t>(x1_type), static_cast<int32_t>(x2_type))
    switch (x1_type) {
        EQUAL_COMPUTE_CASE(DT_INT8, int8_t, ctx)
        EQUAL_COMPUTE_CASE(DT_INT16, int16_t, ctx)
        EQUAL_COMPUTE_CASE(DT_INT32, int32_t, ctx)
        EQUAL_COMPUTE_CASE(DT_INT64, int64_t, ctx)
        EQUAL_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
        EQUAL_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
        EQUAL_COMPUTE_CASE(DT_FLOAT, float, ctx)
        EQUAL_COMPUTE_CASE(DT_DOUBLE, double, ctx)
        EQUAL_COMPUTE_CASE(DT_BOOL, bool, ctx)
        EQUAL_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
        EQUAL_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
        default:
            // The dtype enum is cast explicitly (matching the check above) so the
            // variadic argument agrees with %d; the call fails, so log at ERROR level.
            KERNEL_LOG_ERROR("Equal kernel data type [%d] not support.", static_cast<int32_t>(x1_type));
            return KERNEL_STATUS_PARAM_INVALID;
    }
    return KERNEL_STATUS_OK;
}
// Bind the op type name kEqual ("Equal") to EqualCpuKernel via the framework's
// REGISTER_CPU_KERNEL macro (presumably so the runtime can look this kernel up by op type — confirm).
REGISTER_CPU_KERNEL(kEqual, EqualCpuKernel);
}  // namespace aicpu