/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "utils.h"
#include <thread>
#include <cstdlib>
#include <stdexcept>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-function"
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/core/npu/DeviceUtils.h>
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#pragma GCC diagnostic pop
#include "atb/utils/tensor_util.h"

namespace TorchAtb {
constexpr int32_t DECIMAL = 10;
namespace Utils {
aclrtStream GetCurrentStream()
{
    int32_t devId = 0;
    aclrtGetDevice(&devId);
    aclrtStream stream = c10_npu::getCurrentNPUStream(devId).stream();
    if (stream == nullptr) {
        throw std::runtime_error("get current npu stream fail");
    }
    return stream;
}

atb::Context *GetAtbContext()
{
    static thread_local std::shared_ptr<atb::Context> atbContext;
    if (atbContext) {
        return atbContext.get();
    }
    atb::Context *context = nullptr;
    atb::Status st = atb::CreateContext(&context);
    if (st != atb::NO_ERROR || !context) {
        throw std::runtime_error("create ATB context fail");
    }
    context->SetExecuteStream(GetCurrentStream());
    std::shared_ptr<atb::Context> shardContext(context, [](atb::Context *context) { atb::DestroyContext(context); });
    atbContext = shardContext;
    return atbContext.get();
}

atb::Tensor ConvertToAtbTensor(torch::Tensor &torchTensor)
{
    static const std::map<at::ScalarType, aclDataType> TORCH_TO_ACL_DTYPE_MAP = {
        {at::ScalarType::Bool, ACL_BOOL},    {at::ScalarType::Byte, ACL_UINT8},    {at::ScalarType::Char, ACL_INT8},
        {at::ScalarType::Half, ACL_FLOAT16}, {at::ScalarType::Float, ACL_FLOAT},   {at::ScalarType::Int, ACL_INT32},
        {at::ScalarType::Long, ACL_INT64},   {at::ScalarType::BFloat16, ACL_BF16}, {at::ScalarType::Short, ACL_INT16},
    };

    atb::Tensor atbTensor;
    if (!torchTensor.is_contiguous()) {
        torchTensor = torchTensor.contiguous();
    }
    if (!torchTensor.is_cpu()) {
        atbTensor.desc.format = static_cast<aclFormat>(at_npu::native::get_npu_format(torchTensor));
        atbTensor.deviceData = torchTensor.data_ptr();
    } else {
        atbTensor.hostData = torchTensor.data_ptr();
        atbTensor.desc.format = ACL_FORMAT_ND;
    }
    if (atbTensor.desc.format == ACL_FORMAT_NCHW) {
        atbTensor.desc.format = ACL_FORMAT_ND;
    }
    if (torchTensor.sizes().size() > atb::MAX_DIM) {
        throw std::runtime_error("tensor dimNum " + std::to_string(torchTensor.sizes().size()) +
                                 " is invalid, should be <= MAX_DIM(8)");
    }
    atbTensor.desc.shape.dimNum = torchTensor.sizes().size();
    for (uint64_t i = 0; i < torchTensor.sizes().size(); i++) {
        atbTensor.desc.shape.dims[i] = torchTensor.sizes()[i];
    }

    auto it = TORCH_TO_ACL_DTYPE_MAP.find(torchTensor.scalar_type());
    if (it != TORCH_TO_ACL_DTYPE_MAP.end()) {
        atbTensor.desc.dtype = it->second;
    }
    atbTensor.dataSize = atb::TensorUtil::CalcTensorDataSize(atbTensor);
    return atbTensor;
}

torch::Tensor CreateTorchTensorFromTensorDesc(const atb::TensorDesc &tensorDesc)
{
    static const std::map<aclDataType, at::ScalarType> ACL_TO_TORCH_DTYPE_MAP = {
        {ACL_BOOL, at::ScalarType::Bool},    {ACL_UINT8, at::ScalarType::Byte},    {ACL_INT8, at::ScalarType::Char},
        {ACL_FLOAT16, at::ScalarType::Half}, {ACL_FLOAT, at::ScalarType::Float},   {ACL_INT32, at::ScalarType::Int},
        {ACL_INT64, at::ScalarType::Long},   {ACL_BF16, at::ScalarType::BFloat16}, {ACL_INT16, at::ScalarType::Short},
    };
    at::TensorOptions options = at::TensorOptions();
    auto it = ACL_TO_TORCH_DTYPE_MAP.find(tensorDesc.dtype);
    if (it != ACL_TO_TORCH_DTYPE_MAP.end()) {
        options = options.dtype(it->second);
    }
    options = options.layout(torch::kStrided).requires_grad(false).device(torch_npu::utils::get_npu_device_type());
    torch::Tensor newTensor = at_npu::native::empty_with_format(
        at::IntArrayRef(tensorDesc.shape.dims, tensorDesc.shape.dimNum), options, tensorDesc.format);
    if (!newTensor.is_contiguous()) {
        newTensor = newTensor.contiguous();
    }
    return newTensor;
}
static bool IsEnvEnable(const char *envStr, bool defaultVal)
{
    const char *env = std::getenv(envStr);
    if (env == nullptr) {
        return defaultVal;
    }
    return strtol(env, nullptr, DECIMAL) != 0;
}

bool IsTaskQueueEnable()
{
    static bool isTaskQueueEnable =
        (!IsEnvEnable("ASCEND_LAUNCH_BLOCKING", false) && IsEnvEnable("TASK_QUEUE_ENABLE", true));
    return isTaskQueueEnable;
}
} // namespace Utils
} // namespace TorchAtb