/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "lin_space_aicpu.h"

#include <iostream>

#include "cpu_kernel_utils.h"
#include "securec.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"

namespace {
// Op type name under which this kernel is registered (see REGISTER_CPU_KERNEL below).
constexpr const char *kLinSpace = "LinSpace";
// Expected input count: start, stop, num.
constexpr uint32_t kInputNum = 3;
// Expected output count: the generated sequence.
constexpr uint32_t kOutputNum = 1;
}  // namespace

namespace aicpu {
uint32_t LinSpaceParaCheck(const CpuKernelContext &ctx, int64_t &num_value) {
  Tensor *tensor_start = ctx.Input(kFirstInputIndex);
  Tensor *tensor_stop = ctx.Input(kSecondInputIndex);
  Tensor *tensor_num = ctx.Input(kThirdInputIndex);
  Tensor *tensor_output = ctx.Output(kFirstOutputIndex);

  auto start_shape = tensor_start->GetTensorShape();
  KERNEL_CHECK_FALSE((IsScalar(start_shape->GetDimSizes()) ||
                     ((start_shape->GetDimSizes().size() == 1) &&
                     (start_shape->GetDimSize(0) == 1))), KERNEL_STATUS_PARAM_INVALID,
                     "Input[start] must be a scalar")
  auto stop_shape = tensor_stop->GetTensorShape();
  KERNEL_CHECK_FALSE((IsScalar(stop_shape->GetDimSizes()) ||
                     ((stop_shape->GetDimSizes().size() == 1) &&
                     (stop_shape->GetDimSize(0) == 1))), KERNEL_STATUS_PARAM_INVALID,
                     "Input[stop] must be a scalar")
  auto num_shape = tensor_num->GetTensorShape();
  KERNEL_CHECK_FALSE((IsScalar(num_shape->GetDimSizes()) ||
                     ((num_shape->GetDimSizes().size() == 1) &&
                     (num_shape->GetDimSize(0) == 1))), KERNEL_STATUS_PARAM_INVALID,
                     "Input[num] must be a scalar")
  KERNEL_CHECK_FALSE((tensor_start->GetDataType() == tensor_stop->GetDataType()), KERNEL_STATUS_PARAM_INVALID,
                     "start datatype != stop datatype fail.")
  KERNEL_CHECK_FALSE((tensor_start->GetDataType() == tensor_output->GetDataType()), KERNEL_STATUS_PARAM_INVALID,
                     "start datatype != output datatype fail.")

  auto num_type = static_cast<DataType>(tensor_num->GetDataType());
  switch (num_type) {
    case DT_INT32:
    {
      int32_t *num32 = reinterpret_cast<int32_t *>(tensor_num->GetData());
      num_value = static_cast<int64_t>(*num32);
      break;
    }
    case DT_INT64:
    {
      int64_t *num64 = reinterpret_cast<int64_t *>(tensor_num->GetData());
      num_value = *num64;
      break;
    }
    default:
      KERNEL_LOG_ERROR("num datatype[%d] must be DT_INT32 or DT_INT64 fail.", num_type);
      return KERNEL_STATUS_PARAM_INVALID;
  }
  KERNEL_CHECK_FALSE((num_value > 0), KERNEL_STATUS_PARAM_INVALID, "Input[num] <= 0 fail.")
  return KERNEL_STATUS_OK;
}

template <typename T>
uint32_t LinSpaceCompute(const CpuKernelContext &ctx, int64_t num_value) {
  T *start_value = reinterpret_cast<T *>(ctx.Input(kFirstInputIndex)->GetData());
  T *stop_value = reinterpret_cast<T *>(ctx.Input(kSecondInputIndex)->GetData());
  T *output_value = reinterpret_cast<T *>(ctx.Output(kFirstOutputIndex)->GetData());

  output_value[0] = *start_value;
  if (num_value > 1) {
      T interval = (*stop_value - *start_value) / (num_value - 1);
      for (int64_t i = 1; i < num_value - 1; i++) {
        output_value[i] = *start_value + interval * i;
      }
      output_value[num_value - 1] = *stop_value;
  }

  return KERNEL_STATUS_OK;
}

/**
 * Kernel entry point: validates the context, then dispatches on the dtype of `start`.
 *
 * @param ctx kernel context holding start/stop/num inputs and the output tensor.
 * @return the status of the first failing check, or of LinSpaceCompute<T>;
 *         KERNEL_STATUS_PARAM_INVALID for dtypes other than DT_FLOAT / DT_DOUBLE.
 */
uint32_t LinSpaceCpuKernel::Compute(CpuKernelContext &ctx) {
  int64_t element_count = 0;
  // Generic input/output count checks first, then LinSpace-specific checks,
  // which also extract the requested element count.
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "LinSpace NormalCheck fail.");
  KERNEL_HANDLE_ERROR(LinSpaceParaCheck(ctx, element_count), "LinSpace LinSpaceParaCheck fail.");

  // LinSpaceParaCheck has verified stop/output share start's dtype, so start's dtype decides T.
  const auto start_dtype = static_cast<DataType>(ctx.Input(kFirstInputIndex)->GetDataType());
  if (start_dtype == DT_FLOAT) {
    return LinSpaceCompute<float>(ctx, element_count);
  }
  if (start_dtype == DT_DOUBLE) {
    return LinSpaceCompute<double>(ctx, element_count);
  }
  KERNEL_LOG_ERROR("LinSpace dtype[%d] is invalid.", start_dtype);
  return KERNEL_STATUS_PARAM_INVALID;
}

// Binds LinSpaceCpuKernel to the op type name kLinSpace ("LinSpace") through the
// framework's kernel registration macro.
REGISTER_CPU_KERNEL(kLinSpace, LinSpaceCpuKernel);
}  // namespace aicpu