/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "greater_aicpu.h"

#include <complex>
#include <iostream>

#include "cpu_kernel_utils.h"
#include "kernel_util.h"
#include "log.h"
#include "status.h"

using namespace std;

namespace {
const char* const kGreater = "Greater";
const uint32_t kInputNum = 2;
const uint32_t kOutputNum = 1;
constexpr int32_t kDim1 = 1;
constexpr int32_t kDim2 = 2;
constexpr int32_t kDim3 = 3;
constexpr int32_t kDim4 = 4;
constexpr int32_t kDim5 = 5;
constexpr int32_t kDim6 = 6;
constexpr int32_t kDim7 = 7;
constexpr int32_t kDim8 = 8;
}  // namespace

namespace aicpu {
template <typename T, int32_t RANK>
uint32_t GreaterCpuKernel::BroadcastCompute(TensorMap<T> &x, TensorMap<T> &y,
                                            TensorMap<bool> &out,
                                            const Bcast &bcast) {
  Eigen::DSizes<Eigen::DenseIndex, RANK> x_reshape;
  Eigen::DSizes<Eigen::DenseIndex, RANK> y_reshape;
  Eigen::DSizes<Eigen::DenseIndex, RANK> result_shape;
  Eigen::array<Eigen::DenseIndex, RANK> x_bcast;
  Eigen::array<Eigen::DenseIndex, RANK> y_bcast;

  for (int32_t i = 0; i < RANK; i++) {
    x_reshape[i] = bcast.XReshape()[i];
    y_reshape[i] = bcast.YReshape()[i];
    result_shape[i] = bcast.ResultShape()[i];
    x_bcast[i] = bcast.XBcast()[i];
    y_bcast[i] = bcast.YBcast()[i];
  }
  out.reshape(result_shape) = x.reshape(x_reshape).broadcast(x_bcast) >
                              y.reshape(y_reshape).broadcast(y_bcast);
  return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t GreaterCpuKernel::DoCompute(const CpuKernelContext &ctx) {
  auto input0_tensor = ctx.Input(kFirstInputIndex);
  auto input1_tensor = ctx.Input(kSecondInputIndex);
  DataType input0_data_type = input0_tensor->GetDataType();
  DataType input1_data_type = input1_tensor->GetDataType();
  KERNEL_CHECK_FALSE(
      (input0_data_type == input1_data_type),
      KERNEL_STATUS_PARAM_INVALID,
      "Input[x1] data type[%s] and input[x2] data type[%s] must be same",
      DTypeStr(input0_data_type).c_str(), DTypeStr(input1_data_type).c_str());
  auto input0_shape_sizes = input0_tensor->GetTensorShape()->GetDimSizes();
  auto input0_elements_num = input0_tensor->NumElements();
  TensorMap<T> input0(reinterpret_cast<T *>(input0_tensor->GetData()),
                      input0_elements_num);

  auto input1_shape_sizes = input1_tensor->GetTensorShape()->GetDimSizes();
  auto input1_elements_num = input1_tensor->NumElements();
  TensorMap<T> input1(reinterpret_cast<T *>(input1_tensor->GetData()),
                      input1_elements_num);

  auto output_tensor = ctx.Output(kFirstOutputIndex);
  auto output_elements_num = output_tensor->NumElements();
  TensorMap<bool> output(reinterpret_cast<bool *>(output_tensor->GetData()),
                         output_elements_num);

  Bcast bcast(input0_shape_sizes, input1_shape_sizes);
  if (!bcast.IsValid()) {
    KERNEL_LOG_ERROR("[%s] broadcast failed.", ctx.GetOpType().c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  int32_t rank = static_cast<int32_t>(bcast.XReshape().size());
  switch (rank) {
    case kDim1:
      return BroadcastCompute<T, kDim1>(input0, input1, output, bcast);
    case kDim2:
      return BroadcastCompute<T, kDim2>(input0, input1, output, bcast);
    case kDim3:
      return BroadcastCompute<T, kDim3>(input0, input1, output, bcast);
    case kDim4:
      return BroadcastCompute<T, kDim4>(input0, input1, output, bcast);
    case kDim5:
      return BroadcastCompute<T, kDim5>(input0, input1, output, bcast);
    case kDim6:
      return BroadcastCompute<T, kDim6>(input0, input1, output, bcast);
    case kDim7:
      return BroadcastCompute<T, kDim7>(input0, input1, output, bcast);
    case kDim8:
      return BroadcastCompute<T, kDim8>(input0, input1, output, bcast);
    default:
      KERNEL_LOG_ERROR("Rank[%d] broadcast Compute not support.", rank);
      return KERNEL_STATUS_PARAM_INVALID;
  }
}

uint32_t GreaterCpuKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum),
                      "Check Greater params failed.");
  DataType input0_data_type = ctx.Input(0)->GetDataType();
  KERNEL_LOG_INFO("%s op input[x1] data type is [%s].", kGreater,
                  DTypeStr(input0_data_type).c_str());
  uint32_t ret = KERNEL_STATUS_OK;
  switch (input0_data_type) {
    case DT_FLOAT:
      ret = DoCompute<float>(ctx);
      break;
    case DT_DOUBLE:
      ret = DoCompute<double>(ctx);
      break;
    case DT_FLOAT16:
      ret = DoCompute<Eigen::half>(ctx);
      break;
    case DT_INT16:
      ret = DoCompute<int16_t>(ctx);
      break;
    case DT_INT32:
      ret = DoCompute<int32_t>(ctx);
      break;
    case DT_INT64:
      ret = DoCompute<int64_t>(ctx);
      break;
    case DT_INT8:
      ret = DoCompute<int8_t>(ctx);
      break;
    case DT_UINT16:
      ret = DoCompute<uint16_t>(ctx);
      break;
    case DT_UINT32:
      ret = DoCompute<uint32_t>(ctx);
      break;
    case DT_UINT64:
      ret = DoCompute<uint64_t>(ctx);
      break;
    case DT_UINT8:
      ret = DoCompute<uint8_t>(ctx);
      break;
    default:
      KERNEL_LOG_ERROR("Unsupported input[x1] data type[%s]",
                       DTypeStr(input0_data_type).c_str());
      ret = KERNEL_STATUS_PARAM_INVALID;
  }
  return ret;
}

REGISTER_CPU_KERNEL(kGreater, GreaterCpuKernel);
}  // namespace aicpu