/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file unpack_aicpu.cpp
 * \brief
 */
#include "unpack_aicpu.h"
#include "utils/kernel_util.h"

namespace {
const char *kUnpack = "Unpack";
}

namespace aicpu {
uint32_t UnpackCpuKernel::CheckAndInitParams(CpuKernelContext &ctx) {
  Tensor *value_ptr = ctx.Input(0);
  KERNEL_CHECK_NULLPTR(value_ptr, KERNEL_STATUS_PARAM_INVALID,
                       "Get input value failed.");
  value_data_ptr = value_ptr->GetData();
  KERNEL_CHECK_NULLPTR(value_data_ptr, KERNEL_STATUS_PARAM_INVALID,
                       "Get input value data failed.");
  auto value_shape_ptr = value_ptr->GetTensorShape();
  KERNEL_CHECK_NULLPTR(value_shape_ptr, KERNEL_STATUS_PARAM_INVALID,
                       "Get input value shape failed.");
  int64_t value_dim = value_shape_ptr->GetDims();

  AttrValue *unpack_axis_ptr = ctx.GetAttr("axis");
  int64_t real_unpack_axis = 0;
  KERNEL_CHECK_FALSE(unpack_axis_ptr, KERNEL_STATUS_PARAM_INVALID,
                     "get axis failed!");
  unpack_axis = unpack_axis_ptr->GetInt();
  real_unpack_axis = unpack_axis >= 0 ? unpack_axis : unpack_axis + value_dim;
  KERNEL_CHECK_FALSE(value_dim > real_unpack_axis, KERNEL_STATUS_PARAM_INVALID,
                     "The axis value range should be [-value_dim, value_dim), "
                     "value dim is [%ld], axis is [%llu].",
                     value_dim, unpack_axis);
  unpack_axis = real_unpack_axis;

  AttrValue *unpack_num_ptr = ctx.GetAttr("num");
  KERNEL_CHECK_FALSE(unpack_num_ptr, KERNEL_STATUS_PARAM_INVALID,
                     "get num failed!");
  int64_t axis_size = value_shape_ptr->GetDimSize(unpack_axis);
  unpack_num = unpack_num_ptr->GetInt();
  KERNEL_CHECK_FALSE(unpack_num == axis_size, KERNEL_STATUS_PARAM_INVALID,
                     "The num you want to unpack to should be equal to the "
                     "size of the specified dimension. "
                     "The num you want to unpack to is [%ld], while the [%llu] "
                     "dim's size is [%ld].",
                     unpack_num, unpack_axis, axis_size);
  value_shape_vec = value_shape_ptr->GetDimSizes();
  data_type = value_ptr->GetDataType();
  value_num = value_ptr->NumElements();

  output_ptr_vec.resize(unpack_num);
  for (int64_t i = 0; i < unpack_num; i++) {
    Tensor *output_ptr = ctx.Output(i);
    KERNEL_CHECK_NULLPTR(output_ptr, KERNEL_STATUS_PARAM_INVALID,
                         "Get output [%ld] failed.", i);
    auto output_data_ptr = output_ptr->GetData();
    KERNEL_CHECK_NULLPTR(output_data_ptr, KERNEL_STATUS_PARAM_INVALID,
                         "Get output data [%ld] failed.", i);
    output_ptr_vec[i] = output_data_ptr;
  }
  return KERNEL_STATUS_OK;
}

template <typename T>
uint32_t UnpackCpuKernel::UnpackWithOneOutput(
    T *input_data_ptr, std::vector<T *> output_data_vec) {
  int64_t copy_size = value_num * sizeof(T);
  auto mem_ret =
      memcpy_s(output_data_vec[0], copy_size, input_data_ptr, copy_size);
  KERNEL_CHECK_FALSE((mem_ret == EOK), KERNEL_STATUS_PARAM_INVALID,
                     "Memcpy size[%zu] from input value to output[0] failed.",
                     copy_size);
  return KERNEL_STATUS_OK;
}

template <typename T>
uint32_t UnpackCpuKernel::UnpackWithDimZero(T *input_data_ptr,
                                            std::vector<T *> output_data_vec) {
  if (value_shape_vec[0] == 0) {
    KERNEL_CHECK_FALSE(value_shape_vec[0] > 0, KERNEL_STATUS_PARAM_INVALID,
                       "The shape of input tensor is invalid.");
  }
  int64_t copy_num = value_num / value_shape_vec[0];
  T *input_copy_ptr = input_data_ptr;
  for (int64_t i = 0; i < unpack_num; i++) {
    int64_t copy_size_per = copy_num;
    auto copy_size = copy_size_per * sizeof(T);
    auto mem_ret =
        memcpy_s(output_data_vec[i], copy_size, input_copy_ptr, copy_size);
    KERNEL_CHECK_FALSE(
        (mem_ret == EOK), KERNEL_STATUS_PARAM_INVALID,
        "Memcpy size[%zu] from input value to output[%ld] failed.", copy_size,
        i);
    input_copy_ptr += copy_size_per;
  }
  return KERNEL_STATUS_OK;
}

template <typename T>
uint32_t UnpackCpuKernel::UnpackCompute(T *input_data_ptr,
                                        std::vector<T *> output_data_vec,
                                        CpuKernelContext &ctx) {
  int64_t prefix = 1;
  for (uint64_t i = 0; i < unpack_axis; i++) {
    if (value_shape_vec[i] == 0) {
      KERNEL_CHECK_FALSE(value_shape_vec[i] > 0, KERNEL_STATUS_PARAM_INVALID, "The shape of input tensor is invalid.");
    }
    prefix *= value_shape_vec[i];
  }
  if (unpack_axis >= value_shape_vec.size()) {
    KERNEL_CHECK_FALSE(unpack_axis < value_shape_vec.size(),
                       KERNEL_STATUS_PARAM_INVALID, "input attr axis is invalid.");
  }
  int64_t midfix = value_shape_vec[unpack_axis];
  int64_t subfix = 1;
  for (size_t i = unpack_axis + 1; i < value_shape_vec.size(); i++) {
    if (value_shape_vec[i] == 0) {
      KERNEL_CHECK_FALSE(value_shape_vec[i] > 0, KERNEL_STATUS_PARAM_INVALID,
                         "The shape of input tensor is invalid.");
    }
    subfix *= value_shape_vec[i];
  }
  uint32_t min_core_num = 1;
  int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
  if (max_core_num > unpack_num) {
    max_core_num = unpack_num;
  }

  auto shard_unpack = [&](size_t start, size_t end) {
    int64_t offset = 0;
    for (uint64_t i = start; i < end; i++) {
      offset = i * subfix;
      T *output_data_ptr = output_data_vec[i];
      T *input_copy_ptr = input_data_ptr + offset;
      auto copy_size = subfix * sizeof(T);
      for (int64_t j = 0; j < prefix; j++) {
        auto mem_ret = memcpy_s(output_data_ptr, copy_size, input_copy_ptr, copy_size);
        KERNEL_CHECK_FALSE_VOID((mem_ret == EOK),
            "Memcpy size[%zu] from input value to output[%ld] failed.", copy_size, i);
        input_copy_ptr += (subfix * midfix);
        output_data_ptr += subfix;
      }
    }
  };
  int64_t thread_num = max_core_num > 0 ? unpack_num / max_core_num : 1;
  KERNEL_HANDLE_ERROR(
      CpuKernelUtils::ParallelFor(ctx, unpack_num, thread_num, shard_unpack), "Unpack Compute failed.")
  return KERNEL_STATUS_OK;
}

template <typename T>
uint32_t UnpackCpuKernel::DoCompute(CpuKernelContext &ctx) {
  T *input_data_ptr = reinterpret_cast<T *>(value_data_ptr);
  std::vector<T *> output_data_vec;
  output_data_vec.resize(unpack_num);
  for (int64_t i = 0; i < unpack_num; i++) {
    output_data_vec[i] = reinterpret_cast<T *>(output_ptr_vec[i]);
  }
  if (unpack_num == 1) {
    KERNEL_CHECK_FALSE(
        (UnpackWithOneOutput<T>(input_data_ptr, output_data_vec) ==
         KERNEL_STATUS_OK),
        KERNEL_STATUS_PARAM_INVALID, "UnpackWithOneOutput failed.");
    return KERNEL_STATUS_OK;
  }
  if (unpack_axis == 0) {
    KERNEL_CHECK_FALSE((UnpackWithDimZero<T>(input_data_ptr, output_data_vec) ==
                        KERNEL_STATUS_OK),
                       KERNEL_STATUS_PARAM_INVALID,
                       "UnpackWithDimZero failed.");
    return KERNEL_STATUS_OK;
  }
  KERNEL_CHECK_FALSE((UnpackCompute<T>(input_data_ptr, output_data_vec, ctx) ==
                      KERNEL_STATUS_OK),
                     KERNEL_STATUS_PARAM_INVALID, "Unpack Compute failed.");
  return KERNEL_STATUS_OK;
}

uint32_t UnpackCpuKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_CHECK_FALSE((CheckAndInitParams(ctx) == KERNEL_STATUS_OK),
                     KERNEL_STATUS_PARAM_INVALID, "CheckAndInitParams failed.");
  switch (data_type) {
    case DT_FLOAT16:
      return DoCompute<Eigen::half>(ctx);
    case DT_FLOAT:
      return DoCompute<float>(ctx);
    case DT_DOUBLE:
      return DoCompute<double>(ctx);
    case DT_BOOL:
      return DoCompute<bool>(ctx);
    case DT_INT8:
      return DoCompute<int8_t>(ctx);
    case DT_INT16:
      return DoCompute<int16_t>(ctx);
    case DT_INT32:
      return DoCompute<int32_t>(ctx);
    case DT_INT64:
      return DoCompute<int64_t>(ctx);
    case DT_UINT8:
      return DoCompute<uint8_t>(ctx);
    case DT_UINT16:
      return DoCompute<uint16_t>(ctx);
    case DT_UINT32:
      return DoCompute<uint32_t>(ctx);
    case DT_UINT64:
      return DoCompute<uint64_t>(ctx);
    case DT_COMPLEX64:
      return DoCompute<std::complex<float>>(ctx);
    case DT_COMPLEX128:
      return DoCompute<std::complex<double>>(ctx);
    default:
      KERNEL_LOG_ERROR("Unsupport data type [%s]", DTypeStr(data_type).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }
}

REGISTER_CPU_KERNEL(kUnpack, UnpackCpuKernel);
}  // namespace aicpu