/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "graph/ge_error_codes.h"
#include "register/kernel_registry_impl.h"
#include "common/checker.h"
#include "exe_graph/runtime/tensor.h"
#include "exe_graph/runtime/continuous_vector.h"
#include "core/debug/kernel_tracing.h"
#include "kernel/common_kernel_impl/calc_tenorsize_from_shape.h"
#include "runtime/rt_ffts_plus.h"
#include "runtime/rt_ffts_plus_define.h"
#include "exe_graph/runtime/gert_tensor_data.h"
#include "kernel/memory/ffts_mem_allocator.h"

namespace gert {
namespace kernel {
// Pair of SDMA data lengths produced by CalcFftsThreadDataLen and consumed by SdmaUpdateContext.
// Both members default to 0 so that a default-initialized holder never exposes indeterminate values.
struct SdmaDataLenHolder {
  size_t tail_len{0U};      // data length written to the context's tailDataLength
  size_t non_tail_len{0U};  // data length written to the context's nonTailDataLength
};

// Input indices read by the CalcFftsThreadDataLen kernel and its tracer.
enum class CalcLenKey {
  NON_TAIL_SLICE = 0,  // ContinuousVector whose data is read as the non-tail Shape
  TAIL_SLICE = 1,      // ContinuousVector whose data is read as the tail Shape
  DTYPE = 2,           // ge::DataType used for the size calculation
  SLICE_FLAG = 3,      // bool selecting the slice branch
  THREAD_DUM = 4,      // uint32_t thread dim; spelling kept because other code in this file references it
  RESERVED             // one past the last valid index
};

// Input indices read by the SdmaUpdateContext kernel and its tracer.
enum class SdmaUpdateKey {
  CTX_ID = 0,           // ContinuousVector of int32_t context ids
  THREAD_DIM = 1,       // uint32_t thread dim written to every updated context
  WINDOW_SIZE = 2,      // uint32_t number of leading context ids to process
  SDMA_LEN = 3,         // SdmaDataLenHolder with the tail / non-tail lengths
  INPUT_MEM_TYPE = 4,   // uint32_t; 0 reads INPUT_TENSOR as GertTensorData, otherwise as memory::FftsMemBlock
  OUTPUT_MEM_TYPE = 5,  // uint32_t; 0 reads OUTPUT_TENSOR as GertTensorData, otherwise as memory::FftsMemBlock
  INPUT_TENSOR = 6,     // source memory, interpreted according to INPUT_MEM_TYPE
  OUTPUT_TENSOR = 7,    // destination memory, interpreted according to OUTPUT_MEM_TYPE
  RESERVED              // one past the last valid index
};

/**
 * Updates the SDMA contexts of an FFTS+ task in place.
 *
 * Inputs are read at the indices named by SdmaUpdateKey; output 0 is read as an rtFftsPlusTaskInfo_t whose
 * descBuf is indexed as an array of rtFftsPlusSdmaCtx_t. For each of the first WINDOW_SIZE context ids the
 * thread dim, the tail / non-tail data lengths and the source / destination base addresses and offsets are
 * rewritten. A memory type of 0 reads the tensor input as GertTensorData (offset = non-tail length); any other
 * value reads it as memory::FftsMemBlock and uses the address of slot idx (offset = 0).
 *
 * @param context kernel context providing the inputs and the task-info output
 * @return ge::GRAPH_SUCCESS when every context was updated; ge::FAILED when a context id lies outside
 *         [0, totalContextNum). The GE_ASSERT_* checks additionally reject null pointers, a zero thread dim
 *         and a window size larger than the number of context ids.
 */
ge::graphStatus SdmaUpdateContext(KernelContext *const context) {
  GELOGD("SdmaUpdateContext begin.");
  const auto ctx_ids = context->GetInputPointer<gert::ContinuousVector>(static_cast<size_t>(SdmaUpdateKey::CTX_ID));
  const auto thread_dim = context->GetInputValue<uint32_t>(static_cast<size_t>(SdmaUpdateKey::THREAD_DIM));
  const auto window_size = context->GetInputValue<uint32_t>(static_cast<size_t>(SdmaUpdateKey::WINDOW_SIZE));
  const auto sdma_len = context->GetInputPointer<SdmaDataLenHolder>(static_cast<size_t>(SdmaUpdateKey::SDMA_LEN));
  const auto input_mem_type =
      context->GetInputValue<uint32_t>(static_cast<size_t>(SdmaUpdateKey::INPUT_MEM_TYPE));
  const auto output_mem_type =
      context->GetInputValue<uint32_t>(static_cast<size_t>(SdmaUpdateKey::OUTPUT_MEM_TYPE));

  auto task_info = context->GetOutputPointer<rtFftsPlusTaskInfo_t>(0UL);

  GE_ASSERT_NOTNULL(sdma_len);
  GE_ASSERT_NOTNULL(task_info);
  GE_ASSERT_NOTNULL(task_info->descBuf);
  GE_ASSERT_NOTNULL(task_info->fftsPlusSqe);
  GE_ASSERT_NOTNULL(ctx_ids);

  GE_ASSERT_TRUE(thread_dim != 0, "Thread dim val [0] is invalid");

  rtFftsPlusSdmaCtx_t *const ctx_head = reinterpret_cast<rtFftsPlusSdmaCtx_t *>(const_cast<void *>(task_info->descBuf));
  const auto ctx_id_vec = reinterpret_cast<const int32_t *>(ctx_ids->GetData());
  GE_ASSERT_NOTNULL(ctx_id_vec);
  const size_t ctx_num = ctx_ids->GetSize();
  const uint16_t total_num = task_info->fftsPlusSqe->totalContextNum;
  GE_ASSERT_TRUE(window_size <= ctx_num);

  // The context stores a 64-bit address as two 32-bit halves.
  const auto low32 = [](const uint64_t addr) { return static_cast<uint32_t>(addr & 0xFFFFFFFFUL); };
  const auto high32 = [](const uint64_t addr) { return static_cast<uint32_t>(addr >> 32U); };

  for (uint32_t idx = 0U; idx < window_size; ++idx) {
    const int32_t ctx_id = ctx_id_vec[idx];
    // Ids index descBuf directly, so only [0, totalContextNum) is addressable; negative ids and an id equal to
    // totalContextNum would point outside the descriptor buffer.
    if ((ctx_id < 0) || (ctx_id >= static_cast<int32_t>(total_num))) {
      GELOGE(ge::FAILED, "Context id [%d] is invalid.", ctx_id);
      return ge::FAILED;
    }
    rtFftsPlusSdmaCtx_t *const ctx = ctx_head + ctx_id;
    ctx->threadDim = thread_dim;
    ctx->tailDataLength = sdma_len->tail_len;
    ctx->nonTailDataLength = sdma_len->non_tail_len;

    // Source side.
    if (input_mem_type == 0U) {
      const auto tensor_data =
          context->GetInputPointer<GertTensorData>(static_cast<size_t>(SdmaUpdateKey::INPUT_TENSOR));
      GE_ASSERT_NOTNULL(tensor_data);
      const auto addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(tensor_data->GetAddr()));
      ctx->sourceAddressBaseL = low32(addr);
      ctx->sourceAddressBaseH = high32(addr);
      ctx->sourceAddressOffset = ctx->nonTailDataLength;
    } else {
      const auto ffts_mem =
          context->GetInputPointer<memory::FftsMemBlock>(static_cast<size_t>(SdmaUpdateKey::INPUT_TENSOR));
      GE_ASSERT_NOTNULL(ffts_mem);
      const auto addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(ffts_mem->Addr(idx)));
      ctx->sourceAddressBaseL = low32(addr);
      ctx->sourceAddressBaseH = high32(addr);
      ctx->sourceAddressOffset = 0U;
    }

    // Destination side, mirroring the source handling.
    if (output_mem_type == 0U) {
      const auto tensor_data =
          context->GetInputPointer<GertTensorData>(static_cast<size_t>(SdmaUpdateKey::OUTPUT_TENSOR));
      GE_ASSERT_NOTNULL(tensor_data);
      const auto addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(tensor_data->GetAddr()));
      ctx->destinationAddressBaseL = low32(addr);
      ctx->destinationAddressBaseH = high32(addr);
      ctx->destinationAddressOffset = ctx->nonTailDataLength;
    } else {
      const auto ffts_mem =
          context->GetInputPointer<memory::FftsMemBlock>(static_cast<size_t>(SdmaUpdateKey::OUTPUT_TENSOR));
      GE_ASSERT_NOTNULL(ffts_mem);
      const auto addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(ffts_mem->Addr(idx)));
      ctx->destinationAddressBaseL = low32(addr);
      ctx->destinationAddressBaseH = high32(addr);
      ctx->destinationAddressOffset = 0U;
    }
  }
  GELOGD("SdmaUpdateContext end.");

  return ge::GRAPH_SUCCESS;
}

std::vector<std::string> SdmaContextTracer(const KernelContext *context) {
  auto window_size = context->GetInputValue<uint32_t>(static_cast<size_t>(SdmaUpdateKey::WINDOW_SIZE));
  auto task_info = context->GetOutputPointer<rtFftsPlusTaskInfo_t>(0UL);
  if (task_info == nullptr) {
    return {"task info is nullptr"};
  }
  rtFftsPlusSdmaCtx_t *const ctx_head = reinterpret_cast<rtFftsPlusSdmaCtx_t *>(const_cast<void *>(task_info->descBuf));
  if (ctx_head == nullptr) {
    return {"ctx_head is nullptr"};
  }

  uint16_t total_num = task_info->fftsPlusSqe->totalContextNum;
  auto ctx_ids = context->GetInputPointer<gert::ContinuousVector>(static_cast<size_t>(SdmaUpdateKey::CTX_ID));
  if (ctx_ids == nullptr) {
    return {"context ids is nullptr"};
  }
  auto ctx_id_vec = reinterpret_cast<const int32_t *>(ctx_ids->GetData());
  const size_t ctx_num = ctx_ids->GetSize();
  if (window_size > ctx_num) {
    return {"window size is invalid"};
  }

  std::vector<std::string> strs;
  std::stringstream ss;
  for (uint32_t idx = 0U; idx < window_size; ++idx) {
    if (ctx_id_vec[idx] > total_num) {
      return {"ctx id is out of range"};
    }
    auto ctx = reinterpret_cast<rtFftsPlusSdmaCtx_t *>(ctx_head + ctx_id_vec[idx]);
    if (ctx != nullptr) {
      ss << "idx:" << std::dec << idx << " src_offset:" << ctx->sourceAddressOffset
         << " dst_offset:" << ctx->destinationAddressOffset << std::hex << " src_addr:" << ctx->sourceAddressBaseH
         << " " << ctx->sourceAddressBaseL << " dst_addr:" << ctx->destinationAddressBaseH << " "
         << ctx->destinationAddressBaseL;
      strs.emplace_back(ss.str().c_str());
      ss.clear();
      ss.str("");
    }
  }

  return strs;
}

// Kernel registration for SdmaUpdateContext: the function of the same name is passed as the run function and
// SdmaContextTracer as the trace printer.
REGISTER_KERNEL(SdmaUpdateContext).RunFunc(SdmaUpdateContext).TracePrinter(SdmaContextTracer);

/**
 * Computes the SDMA tail / non-tail data lengths and the matching memory size.
 *
 * Inputs are read at the indices named by CalcLenKey. Output 0 receives the memory size (size_t) and output 1
 * the SdmaDataLenHolder.
 *  - slice_flag == false: the unaligned byte size of the non-tail shape is divided by the thread dim; the result
 *    is stored as both lengths and as the memory size.
 *  - slice_flag == true: the non-tail and tail shapes are sized separately and the memory size is the larger of
 *    the two lengths.
 *
 * @param context kernel context providing the inputs and outputs described above
 * @return ge::GRAPH_SUCCESS on success; the GE_ASSERT_* checks reject null pointers, a failed size calculation,
 *         a zero thread dim in the non-slice branch, and empty or mismatching slice vectors in the slice branch
 */
ge::graphStatus CalcFftsThreadDataLen(KernelContext *const context) {
  const auto non_tail_slice =
      context->GetInputPointer<gert::ContinuousVector>(static_cast<size_t>(CalcLenKey::NON_TAIL_SLICE));
  GE_ASSERT_NOTNULL(non_tail_slice);
  const auto tail_slice =
      context->GetInputPointer<gert::ContinuousVector>(static_cast<size_t>(CalcLenKey::TAIL_SLICE));
  GE_ASSERT_NOTNULL(tail_slice);
  const auto dtype = context->GetInputValue<ge::DataType>(static_cast<size_t>(CalcLenKey::DTYPE));
  const auto slice_flag = context->GetInputValue<bool>(static_cast<size_t>(CalcLenKey::SLICE_FLAG));
  const auto thread_dim = context->GetInputValue<uint32_t>(static_cast<size_t>(CalcLenKey::THREAD_DUM));

  auto mem_size = context->GetOutputPointer<size_t>(0UL);
  GE_ASSERT_NOTNULL(mem_size);
  auto data_len_holder = context->GetOutputPointer<SdmaDataLenHolder>(1UL);
  GE_ASSERT_NOTNULL(data_len_holder);

  // NOTE(review): the first shape is read here before the slice branch asserts that the vector is non-empty;
  // confirm that callers always provide at least one shape in the non-slice case.
  const auto non_tail_shape = reinterpret_cast<const Shape *>(non_tail_slice->GetData());
  GE_ASSERT_NOTNULL(non_tail_shape);
  uint64_t non_tail_size{0UL};
  GE_ASSERT_GRAPH_SUCCESS(CalcUnalignedTensorSizeByShape(*non_tail_shape, dtype, non_tail_size));
  data_len_holder->non_tail_len = non_tail_size;

  if (!slice_flag) {
    // The total size is split evenly across threads, so a zero thread dim would divide by zero.
    GE_ASSERT_TRUE(thread_dim != 0, "Thread dim val [0] is invalid");
    data_len_holder->non_tail_len /= thread_dim;
    data_len_holder->tail_len = data_len_holder->non_tail_len;
    *mem_size = data_len_holder->non_tail_len;
  } else {
    GE_ASSERT_EQ(non_tail_slice->GetSize(), tail_slice->GetSize());
    GE_ASSERT_TRUE(non_tail_slice->GetSize() > 0UL);

    const auto tail_shape = reinterpret_cast<const Shape *>(tail_slice->GetData());
    GE_ASSERT_NOTNULL(tail_shape);
    uint64_t tail_size{0UL};
    GE_ASSERT_GRAPH_SUCCESS(CalcUnalignedTensorSizeByShape(*tail_shape, dtype, tail_size));
    data_len_holder->tail_len = tail_size;
    // One buffer must be able to hold either kind of slice.
    *mem_size = std::max(data_len_holder->non_tail_len, data_len_holder->tail_len);
  }

  return ge::GRAPH_SUCCESS;
}

/**
 * Outputs creator for CalcFftsThreadDataLen: installs a value-initialized SdmaDataLenHolder on output 1.
 *
 * @param node    unused
 * @param context kernel context whose output 1 receives the holder
 * @return ge::GRAPH_SUCCESS on success; the GE_ASSERT_* checks reject a missing output chain and a failed
 *         allocation
 */
ge::graphStatus BuildSdmaDataLenHolder(const ge::FastNode *node, KernelContext *context) {
  (void)node;
  // Resolve the output chain before allocating, so a missing output cannot leak a freshly created holder.
  auto chain = context->GetOutput(1UL);
  GE_ASSERT_NOTNULL(chain);
  auto *len_holder = new (std::nothrow) SdmaDataLenHolder();
  GE_ASSERT_NOTNULL(len_holder);
  // Presumably the chain takes ownership and frees the holder with its default deleter - confirm against Chain.
  chain->SetWithDefaultDeleter(len_holder);
  return ge::GRAPH_SUCCESS;
}

/**
 * Trace printer for the CalcFftsThreadDataLen kernel.
 *
 * @param context kernel context of the CalcFftsThreadDataLen kernel
 * @return a single line with the slice flag, the thread dim and both computed data lengths, or a diagnostic
 *         string when output 1 holds no SdmaDataLenHolder
 */
std::vector<std::string> SdmaCalcTracer(const KernelContext *context) {
  const auto slice_flag = context->GetInputValue<bool>(static_cast<size_t>(CalcLenKey::SLICE_FLAG));
  const auto thread_dim = context->GetInputValue<uint32_t>(static_cast<size_t>(CalcLenKey::THREAD_DUM));
  const auto data_len_holder = context->GetOutputPointer<SdmaDataLenHolder>(1UL);
  if (data_len_holder == nullptr) {
    return {"SdmaDataLenHolder is nullptr"};
  }
  // Every field gets its own label and a separating space; otherwise adjacent numbers run together and the
  // two lengths cannot be told apart in the trace.
  std::stringstream ss;
  ss << "slice_flag:" << slice_flag << " thread_dim:" << thread_dim
     << " non_tail_len:" << data_len_holder->non_tail_len << " tail_len:" << data_len_holder->tail_len;
  return {ss.str()};
}

// Kernel registration for CalcFftsThreadDataLen: the function of the same name is passed as the run function,
// BuildSdmaDataLenHolder as the outputs creator and SdmaCalcTracer as the trace printer.
REGISTER_KERNEL(CalcFftsThreadDataLen)
    .RunFunc(CalcFftsThreadDataLen)
    .OutputsCreator(BuildSdmaDataLenHolder)
    .TracePrinter(SdmaCalcTracer);
}  // namespace kernel
}  // namespace gert