// Copyright (c) 2024 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>

#include <ATen/native/ForeachUtils.h>

#include "op_plugin/utils/op_api_common.h"
#include "torch_npu/csrc/framework/utils/UtilForOpAdapter.h"
#include "op_plugin/utils/custom_functions/opapi/ForeachConstants.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/AclOpsInterface.h"
#include "third_party/acl/inc/acl/acl_rt.h"
#include "torch_npu/csrc/core/npu/interface/AclInterface.h"

namespace op_api {
namespace {
    // Decides whether the two tensor lists need to be split into per-dtype groups.
    //
    // Returns true only when every tensor in both lists lives on the same device as the
    // reference tensor (first element of `tensors`, or of `tensors2` when `tensors` is empty)
    // and at least one tensor has a scalar type different from the reference.
    // Returns false when both lists are empty or when any tensor sits on another device.
    bool should_group(const at::TensorList tensors, const at::TensorList tensors2)
    {
        if (tensors.empty() && tensors2.empty()) {
            return false;
        }

        const at::Tensor& reference = tensors.empty() ? tensors2[0] : tensors[0];
        const auto reference_device = reference.device();
        const auto reference_dtype = reference.scalar_type();

        bool mixed_dtypes = false;
        const at::TensorList* lists[] = {&tensors, &tensors2};
        for (const at::TensorList* list : lists) {
            for (const auto& tensor : *list) {
                // A single tensor on another device rules grouping out altogether.
                if (tensor.device() != reference_device) {
                    return false;
                }
                mixed_dtypes |= (tensor.scalar_type() != reference_dtype);
            }
        }

        return mixed_dtypes;
    }
} // namespace anonymous

#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
// Byte-width constants (4 bytes and 2 bytes).
// NOTE(review): the names suggest 4-byte and 2-byte element widths — confirm intended use.
const size_t SIZE_OF_NOT_INT = 4;
const size_t SIZE_OF_SHORT = 2;
// Short aliases for the NPU OpPreparation and CalcuOpUtil helper classes.
using npu_preparation = at_npu::native::OpPreparation;
using npu_calcu_util = at_npu::native::CalcuOpUtil;

// Issues one aclnnForeachCopy command through EXEC_NPU_CMD, passing `src` first and `dst` second.
// `non_blocking` is part of the signature but is not referenced in this body.
void exec_npu_cmd_copy(const at::TensorList dst, at::TensorList src, bool non_blocking)
{
    EXEC_NPU_CMD(aclnnForeachCopy, src, dst);
}

// Runs the foreach copy over `src`/`dst` in slices of at most SINGLE_FOREACH_OP_TENSOR_COUNT tensors.
// Lists that already fit within that limit (including empty lists) are forwarded in a single call.
// Slice boundaries are derived from src.size(); `dst` is indexed with the same offsets.
void split_and_exec_npu_cmd_copy(const at::TensorList dst, at::TensorList src, bool non_blocking)
{
    const size_t total = src.size();
    const size_t chunk_limit = SINGLE_FOREACH_OP_TENSOR_COUNT;

    if (total <= chunk_limit) {
        exec_npu_cmd_copy(dst, src, non_blocking);
        return;
    }

    // Walk both lists with a shared offset; the final slice carries whatever is left over.
    for (size_t offset = 0; offset < total; offset += chunk_limit) {
        const size_t remaining = total - offset;
        const size_t chunk = remaining < chunk_limit ? remaining : chunk_limit;
        at::TensorList src_slice(src.data() + offset, chunk);
        at::TensorList dst_slice(dst.data() + offset, chunk);
        exec_npu_cmd_copy(dst_slice, src_slice, non_blocking);
    }
}

// Reports whether the scalar type of the first tensor in `src` is accepted by the foreach copy
// kernel path. Only src[0] is inspected; an empty list is reported as unsupported.
//
// Accepted unconditionally: Int, Short, Char, Byte, BFloat16, Float, Half.
// Accepted only when is_gte_cann_version_810rc1() is true: Long, Double, Bool.
//
// The decision is made purely on the scalar type. sizeof(src[0]) must not be used as a proxy for
// the element width: it yields the size of the at::Tensor handle object, a compile-time constant
// that does not depend on the dtype. (Presumably the handle is pointer-sized on the 64-bit
// targets this is built for, so comparing it against 4 or 2 never selected anything.)
bool check_tensor_dtype_support_base(const at::TensorList src)
{
    if (src.empty()) {
        return false;
    }

    switch (src[0].scalar_type()) {
        case at::ScalarType::Int:
        case at::ScalarType::Short:
        case at::ScalarType::Char:
        case at::ScalarType::Byte:
        case at::ScalarType::BFloat16:
        case at::ScalarType::Float:
        case at::ScalarType::Half:
            return true;
        case at::ScalarType::Long:
        case at::ScalarType::Double:
        case at::ScalarType::Bool:
            // These types are gated on the CANN version reported by the helper.
            return op_plugin::utils::is_gte_cann_version_810rc1();
        default:
            return false;
    }
}

// Validates that a dst/src list pair has one consistent copy direction.
//
// Returns true only when:
//   - both lists are non-empty and have the same length;
//   - every dst shares the device type of dsts[0], and every src that of srcs[0];
//   - tensors on a PrivateUse1 device all share the device index of their list's first tensor;
//   - when non_blocking is set, that PrivateUse1 index equals the index reported by
//     c10_npu::GetDevice (cross-card batch copy is not supported for the asynchronous case);
//   - the dst and src device types differ (same-device-type copies, e.g. d2d, are excluded).
bool check_tensor_device_dtype_base(const at::TensorList dsts, const at::TensorList srcs, bool non_blocking)
{
    if (dsts.empty() || srcs.empty() || dsts.size() != srcs.size()) {
        return false;
    }

    // Direction consistency: the first tensor of each list defines the expected device.
    const c10::DeviceType dst_device_type = dsts[0].device().type();
    const c10::DeviceType src_device_type = srcs[0].device().type();
    const bool dst_on_npu = (dst_device_type == c10::DeviceType::PrivateUse1);
    const bool src_on_npu = (src_device_type == c10::DeviceType::PrivateUse1);

    // The asynchronous case needs a check against the current device.
    int current_npu_index = -1;
    if (non_blocking) {
        c10_npu::GetDevice(&current_npu_index);
    }

    // -1 means "no index constraint" for a list that is not on a PrivateUse1 device.
    const int dst_npu_index = dst_on_npu ? dsts[0].device().index() : -1;
    const int src_npu_index = src_on_npu ? srcs[0].device().index() : -1;
    if (non_blocking) {
        if (dst_on_npu && current_npu_index != dst_npu_index) {
            return false;
        }
        if (src_on_npu && current_npu_index != src_npu_index) {
            return false;
        }
    }

    const auto all_on_device = [](const at::TensorList &list, c10::DeviceType device_type, int device_index) {
        for (const auto &tensor : list) {
            if (tensor.device().type() != device_type) {
                return false;
            }
            if (device_index != -1 && tensor.device().index() != device_index) {
                return false;
            }
        }
        return true;
    };
    if (!all_on_device(dsts, dst_device_type, dst_npu_index)) {
        return false;
    }
    if (!all_on_device(srcs, src_device_type, src_npu_index)) {
        return false;
    }

    // Exclude the d2d scenario (and any other pair whose device types match).
    return dst_device_type != src_device_type;
}

// Copies every src[i] into dst[i] with a single ACL batch-memcpy call.
//
// For each pair the data pointer, byte length (numel * element_size) and a per-copy attribute
// describing the memory location of each side are collected. A tensor on a PrivateUse1 device is
// described as DEVICE memory with its device index as id; any other tensor is described as HOST.
//
// non_blocking == true : the batch is submitted via LaunchBatchAsyncCopyTask on the current stream.
// non_blocking == false: the current stream is synchronized first, then AclrtMemcpyBatch is called.
//
// Raises through TORCH_CHECK when the two lists differ in length, and through NPU_CHECK_ERROR /
// AT_ERROR on ACL failures (a synchronize failure only warns when resume mode is enabled).
void memcpyBatch(const at::TensorList dst, at::TensorList src, bool non_blocking)
{
    TORCH_CHECK(dst.size() == src.size(), "dst and src size must be equal, but in reality the dst size is ",
                dst.size(), " and the src size is ", src.size(), "." + OPS_ERROR(ErrCode::PARAM));
    const size_t count = dst.size();

    // Heap-backed buffers instead of variable-length arrays: VLAs are not standard C++ and would
    // put an unbounded amount of data on the stack for long tensor lists.
    // NOTE(review): these buffers only live until this function returns — confirm that
    // LaunchBatchAsyncCopyTask does not retain the pointers beyond the call.
    std::vector<void *> dsts(count);
    std::vector<void *> srcs(count);
    std::vector<size_t> dstLens(count);
    std::vector<size_t> srcLens(count);
    std::vector<size_t> attrsIndexes(count);
    // Value-initialized, so every field (including the reserved bytes) starts out zeroed.
    std::vector<aclrtMemcpyBatchAttr> attrs(count);

    // Describes where a tensor's memory lives; the struct is zeroed first so the id of a host
    // location is never left indeterminate.
    const auto location_of = [](const at::Tensor &tensor) {
        aclrtMemLocation location{};
        if (tensor.device().type() == c10::DeviceType::PrivateUse1) {
            const int npu_device_index = tensor.device().index();
            location.id = npu_device_index;
            location.type = aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE;
        } else {
            location.type = aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST;
        }
        return location;
    };

    for (size_t i = 0; i < count; ++i) {
        const at::Tensor &dst_tensor = dst[i];
        const at::Tensor &src_tensor = src[i];
        // Tensor addresses.
        dsts[i] = dst_tensor.data_ptr();
        srcs[i] = src_tensor.data_ptr();
        // Tensor sizes in bytes.
        dstLens[i] = static_cast<size_t>(dst_tensor.numel() * dst_tensor.element_size());
        srcLens[i] = static_cast<size_t>(src_tensor.numel() * src_tensor.element_size());
        // One attribute entry per copy.
        attrsIndexes[i] = i;
        attrs[i].dstLoc = location_of(dst_tensor);
        attrs[i].srcLoc = location_of(src_tensor);
    }

    size_t failIdx = SIZE_MAX;
    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false);
    if (non_blocking) {
        auto ret = c10_npu::queue::LaunchBatchAsyncCopyTask(dsts.data(), dstLens.data(), srcs.data(), srcLens.data(),
                                                            count, attrs.data(), attrsIndexes.data(), count,
                                                            acl_stream);
        NPU_CHECK_ERROR(ret, "aclrtMemcpyBatchAsync");
    } else {
        // Synchronize the current stream before issuing the blocking batch copy.
        aclError error = c10_npu::acl::AclrtSynchronizeStreamWithTimeout(acl_stream);
        if (error != ACL_ERROR_NONE) {
            CHECK_AND_THROW_ERROR_WITH_SPECIFIC_MESSAGE(error);
            C10_NPU_SHOW_ERR_MSG();
            if (c10_npu::option::OptionsManager::IsResumeModeEnable()) {
                TORCH_NPU_WARN("ACL stream synchronize failed, error code:", error,
                               ". But in checkpoint-resume mode will not throw exceptions.");
            } else {
                AT_ERROR("ACL stream synchronize failed, error code:", error);
            }
        }
        auto ret = c10_npu::acl::AclrtMemcpyBatch(dsts.data(), dstLens.data(), srcs.data(), srcLens.data(), count,
                                                  attrs.data(), attrsIndexes.data(), count, &failIdx);
        NPU_CHECK_ERROR(ret, "aclrtMemcpyBatch");
    }
}

// Returns true for the dtype pairs accepted for a cross-dtype copy:
// Float -> Half/BFloat16 and Half/BFloat16 -> Float. Every other pair yields false.
bool is_supported_cross_dtype_copy(at::ScalarType src_dtype, at::ScalarType dst_dtype)
{
    const auto is_half_or_bf16 = [](at::ScalarType dtype) {
        return dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16;
    };

    const bool float_to_16bit = (src_dtype == at::ScalarType::Float) && is_half_or_bf16(dst_dtype);
    const bool from_16bit_to_float = is_half_or_bf16(src_dtype) && (dst_dtype == at::ScalarType::Float);
    return float_to_16bit || from_16bit_to_float;
}

// Returns true when is_gte_cann_version_910() holds, both lists have the same length, and every
// (src[i], dst[i]) dtype pair is accepted by is_supported_cross_dtype_copy.
// Two empty lists pass the per-pair check trivially.
bool check_cross_dtype_supported(const at::TensorList dst, const at::TensorList src)
{
    const bool preconditions_met = op_plugin::utils::is_gte_cann_version_910() && dst.size() == src.size();
    if (!preconditions_met) {
        return false;
    }

    // Advance until the first unsupported pair; reaching the end means all pairs passed.
    size_t idx = 0;
    while (idx < src.size() && is_supported_cross_dtype_copy(src[idx].scalar_type(), dst[idx].scalar_type())) {
        ++idx;
    }
    return idx == src.size();
}

// Picks one of three copy strategies for a (self, src) list pair:
//   1. the aclnnForeachCopy kernel path (split_and_exec_npu_cmd_copy) when is_support_nd_out is
//      set, the lists can use the fast route or form a supported cross-dtype copy, and the dtype
//      of src passes check_tensor_dtype_support_base;
//   2. otherwise memcpyBatch, when is_support_batch is set, the matching ACL batch-memcpy entry
//      point exists for the requested blocking mode, and check_tensor_device_dtype_base passes;
//   3. otherwise a warning is logged and the generic slow foreach copy is used.
void process_tensor_list_batch(const at::TensorList self, const at::TensorList src, bool non_blocking,
    bool is_support_nd_out, bool is_support_batch)
{
    const bool can_fast_route = at::native::can_use_fast_route(self, src);
    const bool cross_dtype_supported = !can_fast_route && check_cross_dtype_supported(self, src);

    const bool kernel_path_ok = is_support_nd_out && (can_fast_route || cross_dtype_supported) &&
                                check_tensor_dtype_support_base(src);
    if (kernel_path_ok) {
        split_and_exec_npu_cmd_copy(self, src, /*non_blocking=*/true);
        return;
    }

    if (is_support_batch) {
        // Only the entry point matching the requested blocking mode is probed.
        const bool batch_api_present =
            non_blocking ? c10_npu::acl::IsExistMemcpyBatchAsync() : c10_npu::acl::IsExistMemcpyBatch();
        if (batch_api_present && check_tensor_device_dtype_base(self, src, non_blocking)) {
            memcpyBatch(self, src, non_blocking);
            return;
        }
    }

    ASCEND_LOGW(
        "The current situation does not support the use of the memcpyBatch interface in the foreach copy interface."
        "There may be the following reasons:1.SOC version is not supported; 2.CANN version is not supported;"
        "3.The direction of the tensor devices for srcs and dsts is inconsistent,"
        "and mixed H2D and D2H scenarios are not supported."
        "For example, all tensors on the srcList must be on the host or device side");
    at::native::foreach_tensor_copy_list_kernel_slow_(self, src, non_blocking);
}

// Entry point of the foreach copy: copies each src[i] into self[i].
//
// After the DO_COMPATIBILITY check for aclnnForeachCopy (which is given the generic slow copy as
// its alternative expression) and the foreach API restriction checks, the lists are either
// processed as a whole or, when should_group reports mixed dtypes on a common device, bucketed by
// the scalar type of the source tensor and processed one bucket at a time.
void _foreach_copy_(const at::TensorList self, const at::TensorList src, bool non_blocking)
{
    DO_COMPATIBILITY(aclnnForeachCopy, at::native::foreach_tensor_copy_list_kernel_slow_(self, src, non_blocking));
    at::native::check_foreach_api_restrictions(self, src);

    // SoC capability flags, evaluated once on first use.
    static const bool is_support_nd_out = [] {
        const auto soc = c10_npu::GetSocVersion();
        const bool in_base_range =
            soc >= c10_npu::SocVersion::Ascend910B1 && soc < c10_npu::SocVersion::Ascend310B1;
        return in_base_range || soc > c10_npu::SocVersion::Ascend310B4;
    }();
    static const bool is_support_batch = [] {
        const auto soc = c10_npu::GetSocVersion();
        const bool in_base_range =
            soc >= c10_npu::SocVersion::Ascend910B1 && soc < c10_npu::SocVersion::Ascend310B1;
        return in_base_range || soc > c10_npu::SocVersion::Ascend910_9391;
    }();

    if (!should_group(src, self)) {
        process_tensor_list_batch(self, src, non_blocking, is_support_nd_out, is_support_batch);
        return;
    }

    // Bucket the (self, src) pairs by source dtype; .first collects self, .second collects src.
    using TensorPairLists = std::pair<std::vector<at::Tensor>, std::vector<at::Tensor>>;
    std::unordered_map<c10::ScalarType, TensorPairLists> buckets;
    for (size_t i = 0; i < src.size(); ++i) {
        auto& bucket = buckets[src[i].scalar_type()];
        bucket.first.push_back(self[i]);
        bucket.second.push_back(src[i]);
    }

    for (const auto& entry : buckets) {
        const TensorPairLists& lists = entry.second;
        at::TensorList bucket_self(lists.first);
        at::TensorList bucket_src(lists.second);
        process_tensor_list_batch(bucket_self, bucket_src, non_blocking, is_support_nd_out, is_support_batch);
    }
}

#endif
} // namespace op_api