/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "llm_datadist_wrapper.h"
#include "common/mem_utils.h"
#include "llm_tensor.h"
#include "common/llm_checker.h"

namespace py = pybind11;

namespace llm {
std::unique_ptr<LLMDataDist> LLMDataDistWrapper::llm_data_dist;

CopyCacheParam LLMDataDistWrapper::UnpackCopyCacheParam(CopyCacheParamTuple cache_param_tuple) {
  constexpr size_t kIndexDstCacheId = 0;
  constexpr size_t kIndexSrcCacheId = 1;
  constexpr size_t kIndexDstCBatchIndex = 2;
  constexpr size_t kIndexSrcCBatchIndex = 3;
  constexpr size_t kIndexOffset = 4;
  constexpr size_t kIndexSize = 5;
  constexpr size_t kIndexReqId = 6;
  constexpr size_t kIndexBlockInfos = 7;
  CopyCacheParam copy_cache_param;
  copy_cache_param.dst_cache_id = std::get<kIndexDstCacheId>(cache_param_tuple);
  copy_cache_param.src_cache_id = std::get<kIndexSrcCacheId>(cache_param_tuple);
  copy_cache_param.dst_batch_index = std::get<kIndexDstCBatchIndex>(cache_param_tuple);
  copy_cache_param.src_batch_index = std::get<kIndexSrcCBatchIndex>(cache_param_tuple);
  copy_cache_param.offset = std::get<kIndexOffset>(cache_param_tuple);
  copy_cache_param.size = std::get<kIndexSize>(cache_param_tuple);
  copy_cache_param.req_id = std::get<kIndexReqId>(cache_param_tuple);
  copy_cache_param.copy_block_infos = std::move(std::get<kIndexBlockInfos>(cache_param_tuple));
  return copy_cache_param;
}

CacheDesc LLMDataDistWrapper::UnpackCacheDesc(const CacheDescTuple &cache_desc_tuple) {
  constexpr size_t kIndexNumTensors = 0;
  constexpr size_t kIndexDataType = 1;
  constexpr size_t kIndexDimIndex = 2;
  constexpr size_t kIndexShape = 3;
  constexpr size_t kPlacement = 4;
  constexpr size_t kIsBlocks = 5;
  CacheDesc cache_desc{};
  cache_desc.num_tensors = std::get<kIndexNumTensors>(cache_desc_tuple);
  cache_desc.data_type = static_cast<ge::DataType>(std::get<kIndexDataType>(cache_desc_tuple));
  cache_desc.seq_len_dim_index = std::get<kIndexDimIndex>(cache_desc_tuple);
  cache_desc.shape = std::get<kIndexShape>(cache_desc_tuple);
  cache_desc.placement = std::get<kPlacement>(cache_desc_tuple);
  auto is_blocks = std::get<kIsBlocks>(cache_desc_tuple);
  cache_desc.cache_mem_type = is_blocks == 0 ? CacheMemType::CACHE : CacheMemType::BLOCKS;
  return cache_desc;
}

CacheKey LLMDataDistWrapper::UnpackCacheKey(const CacheKeyTuple &cache_key_tuple) {
  constexpr size_t kIndexPromptClusterId = 0;
  constexpr size_t kIndexPromptCacheId = 1;
  constexpr size_t kIndexPromptBatchIndex = 2;
  constexpr size_t kIndexReqId = 3;
  constexpr size_t kIndexPrefixId = 4;
  constexpr size_t kIndexModelId = 5;
  constexpr size_t kIndexIsAllocateBlocks = 6;
  CacheKey cache_key{};
  cache_key.prompt_cluster_id = std::get<kIndexPromptClusterId>(cache_key_tuple);
  cache_key.prompt_cache_id = std::get<kIndexPromptCacheId>(cache_key_tuple);
  cache_key.prompt_batch_index = std::get<kIndexPromptBatchIndex>(cache_key_tuple);
  cache_key.req_id = std::get<kIndexReqId>(cache_key_tuple);
  cache_key.prefix_id = std::get<kIndexPrefixId>(cache_key_tuple);
  cache_key.model_id = std::get<kIndexModelId>(cache_key_tuple);
  cache_key.is_allocate_blocks = std::get<kIndexIsAllocateBlocks>(cache_key_tuple);
  return cache_key;
}

Cache LLMDataDistWrapper::UnpackCacheTuple(const CacheTuple &cache_tuple) {
  constexpr size_t kIndexCacheId = 0;
  constexpr size_t kIndexAddrs = 1;
  Cache cache{};
  cache.cache_id = std::get<kIndexCacheId>(cache_tuple);
  cache.per_device_tensor_addrs = std::move(std::get<kIndexAddrs>(cache_tuple));
  return cache;
}

std::vector<CacheKey> LLMDataDistWrapper::UnpackCacheKeys(const std::vector<CacheKeyTuple> &cache_key_tuples) {
  std::vector<CacheKey> cache_keys;
  cache_keys.reserve(cache_key_tuples.size());
  for (const auto &cache_key_tuple : cache_key_tuples) {
    cache_keys.emplace_back(UnpackCacheKey(cache_key_tuple));
  }
  return cache_keys;
}

LLMMemInfo LLMDataDistWrapper::UnpackMemInfo(const MemInfoTuple &mem_info_tuple) {
  constexpr size_t kIndexMemType = 0;
  constexpr size_t kIndexAddr = 1;
  constexpr size_t kIndexSize = 2;
  LLMMemInfo mem_info{};
  mem_info.mem_type = static_cast<LLMMemType>(std::get<kIndexMemType>(mem_info_tuple));
  mem_info.addr = std::get<kIndexAddr>(mem_info_tuple);
  mem_info.size = std::get<kIndexSize>(mem_info_tuple);
  return mem_info;
}

std::vector<LLMMemInfo> LLMDataDistWrapper::UnpackMemInfos(const std::vector<MemInfoTuple> &mem_info_tuples) {
  std::vector<LLMMemInfo> mem_infos;
  mem_infos.reserve(mem_info_tuples.size());
  for (const auto &mem_info_tuple : mem_info_tuples) {
    mem_infos.emplace_back(UnpackMemInfo(mem_info_tuple));
  }
  return mem_infos;
}

PullCacheParam LLMDataDistWrapper::UnpackPullCacheParam(const PullCacheParamTuple &pull_cache_param_tuple) {
  constexpr size_t kIndexSize = 0;
  constexpr size_t kIndexBatchIndex = 1;
  constexpr size_t kIndexPromptBlocks = 2;
  constexpr size_t kIndexDecoderBlocks = 3;
  constexpr size_t kIndexSrcTensorIndices = 4;
  constexpr size_t kIndexDstTensorIndices = 5;
  constexpr size_t kIndexSrcCacheOffset = 6;
  constexpr size_t kIndexDstCacheOffset = 7;
  constexpr size_t kIndexTensorNumPerLayerIndex = 8;
  PullCacheParam pull_cache_param{};
  pull_cache_param.size = std::get<kIndexSize>(pull_cache_param_tuple);
  pull_cache_param.batch_index = std::get<kIndexBatchIndex>(pull_cache_param_tuple);
  pull_cache_param.prompt_blocks = std::get<kIndexPromptBlocks>(pull_cache_param_tuple);
  pull_cache_param.decoder_blocks = std::get<kIndexDecoderBlocks>(pull_cache_param_tuple);
  pull_cache_param.src_tensor_indices = std::get<kIndexSrcTensorIndices>(pull_cache_param_tuple);
  pull_cache_param.dst_tensor_indices = std::get<kIndexDstTensorIndices>(pull_cache_param_tuple);
  pull_cache_param.src_cache_offset = std::get<kIndexSrcCacheOffset>(pull_cache_param_tuple);
  pull_cache_param.dst_cache_offset = std::get<kIndexDstCacheOffset>(pull_cache_param_tuple);
  pull_cache_param.tensor_num_per_layer = std::get<kIndexTensorNumPerLayerIndex>(pull_cache_param_tuple);
  return pull_cache_param;
}

TransferCacheConfig LLMDataDistWrapper::UnpackTransferCacheConfig(
    const TransferCacheConfigTuple &transfer_cache_config_tuple) {
  constexpr size_t kIndexCacheId = 0;
  constexpr size_t kIndexBatchIndex = 1;
  constexpr size_t kIndexLayerIndex = 2;
  constexpr size_t kIndexDstAddrs = 3;
  constexpr size_t kIndexClusterId = 4;
  constexpr size_t kIndexModelId = 5;
  constexpr size_t kIndexDstBatchIndex = 6;
  constexpr size_t kIndexType = 7;
  constexpr size_t kIndexDstLayerIndex = 8;
  constexpr size_t kIndexTensorNumPerLayerIndex = 9;
  TransferCacheConfig transfer_cache_config{};
  transfer_cache_config.src_cache_id = std::get<kIndexCacheId>(transfer_cache_config_tuple);
  transfer_cache_config.batch_index = std::get<kIndexBatchIndex>(transfer_cache_config_tuple);
  transfer_cache_config.layer_index = std::get<kIndexLayerIndex>(transfer_cache_config_tuple);
  transfer_cache_config.dst_addrs = std::get<kIndexDstAddrs>(transfer_cache_config_tuple);
  transfer_cache_config.cluster_id = std::get<kIndexClusterId>(transfer_cache_config_tuple);
  transfer_cache_config.model_id_or_cache_id = std::get<kIndexModelId>(transfer_cache_config_tuple);
  transfer_cache_config.dst_batch_index = std::get<kIndexDstBatchIndex>(transfer_cache_config_tuple);
  transfer_cache_config.type = std::get<kIndexType>(transfer_cache_config_tuple);
  transfer_cache_config.dst_layer_index = std::get<kIndexDstLayerIndex>(transfer_cache_config_tuple);
  transfer_cache_config.tensor_num_per_layer = std::get<kIndexTensorNumPerLayerIndex>(transfer_cache_config_tuple);
  return transfer_cache_config;
}
TransferBlockConfig LLMDataDistWrapper::UnpackTransferBlockConfig(
    const TransferBlockConfigTuple &transfer_block_config_tuple) {
  constexpr size_t kIndexBlockMemSize = 0;
  constexpr size_t kIndexSrcBlocks = 1;
  constexpr size_t kIndexDstBlocks = 2;
  TransferBlockConfig transfer_block_config{};
  transfer_block_config.block_mem_size = std::get<kIndexBlockMemSize>(transfer_block_config_tuple);
  transfer_block_config.src_blocks = std::get<kIndexSrcBlocks>(transfer_block_config_tuple);
  transfer_block_config.dst_blocks = std::get<kIndexDstBlocks>(transfer_block_config_tuple);
  return transfer_block_config;
}

ge::Status LLMDataDistWrapper::Init(uint64_t cluster_id, const std::map<std::string, std::string> &options) {
  LLM_CHK_BOOL_RET_STATUS(llm_data_dist == nullptr, ge::FAILED, "Repeat Init");
  auto instance = llm::MakeUnique<LLMDataDist>(cluster_id);
  LLM_CHECK_NOTNULL(instance);
  std::map<ge::AscendString, ge::AscendString> engine_options;
  for (const auto &option : options) {
    (void) engine_options.emplace(option.first.c_str(), option.second.c_str());
  }
  LLM_CHK_STATUS_RET(instance->LLMDataDistInitialize(engine_options));
  llm_data_dist = std::move(instance);
  return ge::SUCCESS;
}

void LLMDataDistWrapper::Finalize() {
  if (llm_data_dist != nullptr) {
    llm_data_dist->LLMDataDistFinalize();
    llm_data_dist.reset();
  }
}

ge::Status LLMDataDistWrapper::CheckLinkStatus(uint64_t remote_cluster_id) {
  return llm_data_dist->CheckLinkStatus(remote_cluster_id);
}

std::pair<ge::Status, std::vector<ge::Status>> LLMDataDistWrapper::LinkClusters(
    const std::vector<ClusterInfoTuple> &clusters, int32_t timeout) {
  ge::Status ret = ge::FAILED;
  std::vector<ge::Status> rets;
  if (llm_data_dist != nullptr) {
    auto cluster_infos = UnpackClusterInfos(clusters);
    ret = llm_data_dist->LinkClusters(cluster_infos, rets, timeout);
  }
  return {ret, rets};
}

std::pair<ge::Status, std::vector<ge::Status>> LLMDataDistWrapper::UnlinkClusters(
    const std::vector<ClusterInfoTuple> &clusters, int32_t timeout, bool force_flag) {
  ge::Status ret = ge::FAILED;
  std::vector<ge::Status> rets;
  if (llm_data_dist != nullptr) {
    auto cluster_infos = UnpackClusterInfos(clusters);
    ret = llm_data_dist->UnlinkClusters(cluster_infos, rets, timeout, force_flag);
  }
  return {ret, rets};
}

std::pair<ge::Status, CacheTuple> LLMDataDistWrapper::AllocateCache(const CacheDescTuple &cache_desc,
                                                                    const std::vector<CacheKeyTuple> &cache_keys) {
  ge::Status ret = ge::FAILED;
  CacheTuple result;
  if (llm_data_dist != nullptr) {
    Cache cache;
    ret = llm_data_dist->AllocateCache(UnpackCacheDesc(cache_desc), cache, UnpackCacheKeys(cache_keys));
    result = std::make_tuple(cache.cache_id, std::move(cache.per_device_tensor_addrs));
  }
  return {ret, result};
}

ge::Status LLMDataDistWrapper::DeallocateCache(int64_t cache_id) {
  LLM_CHECK_NOTNULL(llm_data_dist);
  return llm_data_dist->DeallocateCache(cache_id);
}

ge::Status LLMDataDistWrapper::PullCache(int64_t cache_id,
                                         const CacheKeyTuple &cache_key,
                                         const PullCacheParamTuple &pull_cache_param) {
  LLM_CHECK_NOTNULL(llm_data_dist);
  return llm_data_dist->PullCache(cache_id, UnpackCacheKey(cache_key), UnpackPullCacheParam(pull_cache_param));
}

ge::Status LLMDataDistWrapper::TransferCache(uint64_t task_id,
                                             const TransferCacheConfigTuple &transfer_cache_config_tuple,
                                             const TransferBlockConfigTuple &transfer_block_config_tuple) {
  LLM_CHECK_NOTNULL(llm_data_dist);
  auto config = UnpackTransferCacheConfig(transfer_cache_config_tuple);
  config.type = 0U;
  return llm_data_dist->TransferCache(task_id, config,
                                      UnpackTransferBlockConfig(transfer_block_config_tuple));
}

ge::Status LLMDataDistWrapper::CopyCache(CopyCacheParamTuple copy_cache_param) {
  LLM_CHECK_NOTNULL(llm_data_dist);
  return llm_data_dist->CopyCache(UnpackCopyCacheParam(std::move(copy_cache_param)));
}

ge::Status LLMDataDistWrapper::RemoveCacheKey(const CacheKeyTuple &cache_key_tuple) {
  LLM_CHECK_NOTNULL(llm_data_dist);
  return llm_data_dist->RemoveCacheKey(UnpackCacheKey(cache_key_tuple));
}

ge::Status LLMDataDistWrapper::SwapBlocks(const CacheTuple &src, const CacheTuple &dst, const uint64_t block_size,
                                          const uint32_t type,
                                          const std::vector<std::pair<int64_t, int64_t>> &block_mapping) {
  LLM_CHECK_NOTNULL(llm_data_dist);
  return llm_data_dist->SwapBlocks(UnpackCacheTuple(src), UnpackCacheTuple(dst), block_size, type, block_mapping);
}

ge::Status LLMDataDistWrapper::SwitchRole(const std::string &role, std::map<std::string, std::string> &options) {
  LLM_CHECK_NOTNULL(llm_data_dist);
  return llm_data_dist->SwitchRole(role, options);
}


std::pair<ge::Status, std::vector<TensorIdAndDesc>> LLMDataDistWrapper::GetCachedTensor(int64_t cache_id,
                                                                                        int32_t tensor_index) {
  ge::Status ret = ge::FAILED;
  std::vector<TensorIdAndDesc> results;
  if (llm_data_dist != nullptr) {
    std::vector<ge::Tensor> tensors;
    ret = llm_data_dist->GetCacheTensors(cache_id, tensors, tensor_index);
    if (ret == ge::SUCCESS) {
      for (auto &tensor : tensors) {
        auto tensor_id = LLMTensor::AddTensor(tensor);
        auto tensor_tuple = std::make_tuple(tensor_id,
                                            static_cast<int32_t>(tensor.GetDataType()),
                                            tensor.GetTensorDesc().GetShape().GetDims());
        results.emplace_back(std::move(tensor_tuple));
      }
    }
  }
  return {ret, results};
}

std::vector<llm::ClusterInfo> LLMDataDistWrapper::UnpackClusterInfos(const std::vector<ClusterInfoTuple> &clusters) {
  constexpr size_t kIndexRemoteClusterId = 0;
  constexpr size_t kIndexRemoteRoleType = 1;
  constexpr size_t kIndexLocalIpInfos = 2;
  constexpr size_t kIndexRemoteIpInfos = 3;
  std::vector<llm::ClusterInfo> cluster_infos;
  for (const auto &cluster : clusters) {
    llm::ClusterInfo cluster_info;
    cluster_info.remote_cluster_id = std::get<kIndexRemoteClusterId>(cluster);
    cluster_info.remote_role_type = std::get<kIndexRemoteRoleType>(cluster);
    for (const auto &ip_and_port : std::get<kIndexLocalIpInfos>(cluster)) {
      IpInfo ip_info;
      ip_info.ip = ip_and_port.first;
      ip_info.port = ip_and_port.second;
      cluster_info.local_ip_infos.emplace_back(ip_info);
    }
    for (const auto &ip_and_port : std::get<kIndexRemoteIpInfos>(cluster)) {
      IpInfo ip_info;
      ip_info.ip = ip_and_port.first;
      ip_info.port = ip_and_port.second;
      cluster_info.remote_ip_infos.emplace_back(ip_info);
    }
    cluster_infos.emplace_back(std::move(cluster_info));
  }
  return cluster_infos;
}
}  // namespace llm