#include <deque>
#include <cstdint>
#include <cstring>
#include <thread>
#include <utility>
#include <vector>

#include <ATen/ATen.h>

#include "torch_npu/csrc/npu/Graph.h"

#include "op_plugin/OpApiInterface.h"
#include "torch_npu/csrc/core/npu/NPUGraph.h"
#include "torch_npu/csrc/core/npu/NPUGraphsUtils.h"
#include "torch_npu/csrc/npu/Event.h"
#include "torch_npu/csrc/npu/Stream.h"

// pybind11 class binding whose instances are held by std::shared_ptr (used for _NPUGraph below).
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
// PyFuncStruct objects handed to c10_npu::launch_callback by _launch_host_func, grouped by stream.
// They stay alive here until _unsubscribe_report deletes them for that stream.
static std::map<c10_npu::NPUStream, std::vector<PyFuncStruct *>> callbacks = {};
// Timeout argument passed to aclrtProcessReport in process_callback (presumably milliseconds —
// confirm against the ACL documentation).
constexpr int processReportTimeout = 100;
// Number of Py_AddPendingCall attempts made by TrySchedulePendingCall before giving up.
constexpr int pendingCallRetryCount = 16;
// Sleep between Py_AddPendingCall attempts (and between retry rounds of the fallback thread).
constexpr auto pendingCallRetryInterval = std::chrono::milliseconds(1);
// Arguments of the report thread created in _subscribe_report; nullptr when no thread was started.
static ThreadArgs* threadArgs = nullptr;
// pthread id of the report thread. -1 (which wraps to UINT64_MAX) means "no thread".
static uint64_t threadId = -1;
// Guards pendingCallbacksQueue; pendingCallScheduled is also updated while holding it.
static std::mutex pendingCallbacksMutex;
// A pending-callback payload together with the host-buffer snapshots taken for this invocation.
using PendingCallbackEntry = std::pair<PendingCallPayload*, std::vector<at::Tensor>>;
// Callbacks fired by the device stream and waiting for PendingCallHandler to run them.
static std::deque<PendingCallbackEntry> pendingCallbacksQueue;
// True while a PendingCallHandler invocation has been scheduled but has not yet drained the queue.
static std::atomic<bool> pendingCallScheduled{false};

// Entry point of the report thread started by _subscribe_report.
// Binds the thread to the ACL context captured at subscription time, then keeps calling
// aclrtProcessReport (presumably where host callbacks of subscribed streams get dispatched)
// until `exitFlag` is set by _unsubscribe_report.
// Takes ownership of `arg` (a heap-allocated ThreadArgs) and deletes it before returning.
void *process_callback(void *arg)
{
    auto* args = static_cast<ThreadArgs *>(arg);
    if (args == nullptr) {
        return nullptr;
    }
    // NOTE(review): the status is ignored, as before. If binding the context fails,
    // aclrtProcessReport presumably fails as well — confirm whether this should be surfaced.
    (void)aclrtSetCurrentContext(args->context);
    // NOTE(review): exitFlag is written from another thread; unless ThreadArgs declares it
    // atomic this read is a data race — confirm its declaration.
    while (!args->exitFlag) {
        (void)aclrtProcessReport(processReportTimeout);
    }
    delete args;
    return nullptr;
}

void LaunchCallFunc(void *userData)
{
    PyGILState_STATE state = PyGILState_Ensure();
    if (userData == nullptr) {
        return;
    }
    auto data = (PyFuncStruct *)(userData);
    PyObject *argslist = Py_BuildValue("(O)", data->pyFuncArgs);
    if (argslist == nullptr) {
        return;
    }
    PyObject *result = PyObject_CallObject(data->pyFunc, argslist);
    if (result == nullptr) {
        return;
    }
    if (argslist != nullptr) {
        Py_XDECREF(argslist);
    }
    if (result != nullptr) {
        Py_XDECREF(result);
    }
    PyGILState_Release(state);
}

namespace {

// First element of the 5-tuple ("host_tensor_ptr", ptr, nbytes, shape, dtype) that a caller of
// _launch_host_func_pending embeds in its user data to request a snapshot of host memory.
constexpr const char* kNpuGraphTensorPtrSpecMarker = "host_tensor_ptr";
// First element of the 4-tuple ("host_tensor_buffer", memoryview, shape, dtype) that replaces each
// "host_tensor_ptr" spec in the arguments handed to the Python callback.
constexpr const char* kNpuGraphTensorBufferSpecMarker = "host_tensor_buffer";

int PendingCallHandler(void *arg);

// Queues PendingCallHandler with Py_AddPendingCall.
// Retries up to pendingCallRetryCount times, sleeping pendingCallRetryInterval between attempts,
// because Py_AddPendingCall can fail, e.g. when its queue is full.
// Returns true once the handler has been queued.
bool TrySchedulePendingCall()
{
    for (int retry = 0; retry < pendingCallRetryCount; ++retry) {
        if (Py_AddPendingCall(PendingCallHandler, nullptr) == 0) {
            return true;
        }
        std::this_thread::sleep_for(pendingCallRetryInterval);
    }
    return false;
}

// Fallback when TrySchedulePendingCall gives up: a detached thread keeps retrying until the
// handler is queued, so queued callbacks are not stranded.
void StartPendingCallRetryThread()
{
    std::thread([]() {
        while (!TrySchedulePendingCall()) {
            std::this_thread::sleep_for(pendingCallRetryInterval);
        }
    }).detach();
}

// True when `obj` is a 5-tuple whose first element is the "host_tensor_ptr" marker string.
bool IsNpuGraphTensorPtrSpec(PyObject* obj)
{
    if (!PyTuple_Check(obj) || PyTuple_GET_SIZE(obj) != 5) {
        return false;
    }
    PyObject* marker = PyTuple_GET_ITEM(obj, 0);
    return PyUnicode_Check(marker) &&
        PyUnicode_CompareWithASCIIString(marker, kNpuGraphTensorPtrSpecMarker) == 0;
}

// Records (ptr, nbytes, shape, dtype) of one "host_tensor_ptr" spec in `payload`.
// On failure a Python exception is set and false is returned.
// Rejecting a null pointer here matters: the copy happens later on the callback thread,
// where a memcpy from address 0 would crash the process.
bool CollectNpuGraphTensorPtrSpec(PyObject* obj, PendingCallPayload* payload)
{
    PyObject* ptrObj = PyTuple_GET_ITEM(obj, 1);
    PyObject* nbytesObj = PyTuple_GET_ITEM(obj, 2);
    PyObject* shapeObj = PyTuple_GET_ITEM(obj, 3);
    PyObject* dtypeObj = PyTuple_GET_ITEM(obj, 4);

    const auto dataPtr = static_cast<uintptr_t>(PyLong_AsUnsignedLongLong(ptrObj));
    if (PyErr_Occurred()) {
        return false;
    }
    const Py_ssize_t nbytes = PyLong_AsSsize_t(nbytesObj);
    if (PyErr_Occurred()) {
        return false;
    }
    if (nbytes < 0) {
        PyErr_SetString(PyExc_ValueError, "npugraph tensor buffer size must be non-negative");
        return false;
    }
    if (dataPtr == 0 && nbytes > 0) {
        PyErr_SetString(PyExc_ValueError, "npugraph tensor buffer pointer must not be null when its size is positive");
        return false;
    }
    // NOTE(review): shapeObj/dtypeObj are stored as borrowed references; they stay valid only as
    // long as the user-data object held by the payload keeps them alive — confirm.
    payload->pendingTensorData.emplace_back(dataPtr, nbytes, shapeObj, dtypeObj);
    return true;
}

// Walks `obj` (recursing into tuples and lists) and records every "host_tensor_ptr" spec in
// traversal order. Other objects are ignored. Returns false with a Python exception set on error.
bool CollectPendingTensorData(PyObject* obj, PendingCallPayload* payload)
{
    if (IsNpuGraphTensorPtrSpec(obj)) {
        return CollectNpuGraphTensorPtrSpec(obj, payload);
    }

    if (PyTuple_Check(obj)) {
        Py_ssize_t size = PyTuple_GET_SIZE(obj);
        for (Py_ssize_t i = 0; i < size; ++i) {
            if (!CollectPendingTensorData(PyTuple_GET_ITEM(obj, i), payload)) {
                return false;
            }
        }
        return true;
    }

    if (PyList_Check(obj)) {
        Py_ssize_t size = PyList_GET_SIZE(obj);
        for (Py_ssize_t i = 0; i < size; ++i) {
            if (!CollectPendingTensorData(PyList_GET_ITEM(obj, i), payload)) {
                return false;
            }
        }
        return true;
    }

    return true;
}

// Snapshots every recorded host buffer into a CPU byte tensor, one slot per recorded spec.
// Runs on the callback thread without the GIL, so it touches no Python objects and never throws.
// A copy that fails (or is skipped after an earlier failure) leaves its slot undefined, which
// MaterializeNpuGraphTensorBufferSpec reports as MemoryError. If even the slot vector cannot be
// allocated, `copiedTensors` is left empty.
void CopyPendingTensorData(PendingCallPayload* payload, std::vector<at::Tensor>& copiedTensors)
{
    copiedTensors.clear();
    const auto& pendingTensorData = payload->pendingTensorData;
    try {
        copiedTensors.resize(pendingTensorData.size());
    } catch (...) {
        copiedTensors.clear();
        return;
    }
    for (size_t i = 0; i < pendingTensorData.size(); ++i) {
        const auto& tensorData = pendingTensorData[i];
        try {
            at::Tensor copiedTensor = at::empty(
                {static_cast<int64_t>(tensorData.nbytes)},
                at::TensorOptions().dtype(at::kByte).device(at::kCPU));
            if (tensorData.nbytes > 0) {
                std::memcpy(
                    copiedTensor.data_ptr(),
                    reinterpret_cast<void*>(tensorData.dataPtr),
                    static_cast<size_t>(tensorData.nbytes));
            }
            copiedTensors[i] = std::move(copiedTensor);
        } catch (...) {
            // The callback cannot run without this buffer; the undefined slot reports the failure.
            break;
        }
    }
}

// Builds ("host_tensor_buffer", memoryview, shape, dtype) for the next recorded spec and advances
// `tensorDataIndex`. Returns a new reference, or nullptr with a Python exception set.
// NOTE(review): the memoryview borrows memory owned by `copiedTensors`, which PendingCallHandler
// clears right after the callback returns; a callback that keeps the buffer (or a tensor sharing
// it) beyond the call would read freed memory — confirm the Python side copies it.
PyObject* MaterializeNpuGraphTensorBufferSpec(
    PendingCallPayload* payload,
    const std::vector<at::Tensor>& copiedTensors,
    size_t& tensorDataIndex)
{
    if (tensorDataIndex >= payload->pendingTensorData.size() ||
        tensorDataIndex >= copiedTensors.size()) {
        PyErr_SetString(PyExc_RuntimeError, "npugraph tensor buffer index is out of range");
        return nullptr;
    }
    const auto& tensorData = payload->pendingTensorData[tensorDataIndex];
    const auto& copiedTensor = copiedTensors[tensorDataIndex++];
    if (!copiedTensor.defined()) {
        // CopyPendingTensorData could not snapshot this buffer.
        PyErr_NoMemory();
        return nullptr;
    }

    PyObject* buffer = PyMemoryView_FromMemory(
        static_cast<char*>(copiedTensor.data_ptr()),
        tensorData.nbytes,
        PyBUF_WRITE);
    if (buffer == nullptr) {
        return nullptr;
    }

    PyObject* marker = PyUnicode_FromString(kNpuGraphTensorBufferSpecMarker);
    if (marker == nullptr) {
        Py_DECREF(buffer);
        return nullptr;
    }
    PyObject* bufferSpec = PyTuple_Pack(4, marker, buffer, tensorData.shape, tensorData.dtype);
    Py_DECREF(marker);
    Py_DECREF(buffer);
    return bufferSpec;
}

// Rebuilds `obj`, replacing each "host_tensor_ptr" spec with its buffer spec.
// Traversal order matches CollectPendingTensorData. Tuples and lists are copied; other objects
// are returned as new references to themselves. Returns nullptr with a Python exception on error.
PyObject* MaterializePendingArg(
    PyObject* obj,
    PendingCallPayload* payload,
    const std::vector<at::Tensor>& copiedTensors,
    size_t& tensorDataIndex)
{
    if (IsNpuGraphTensorPtrSpec(obj)) {
        return MaterializeNpuGraphTensorBufferSpec(payload, copiedTensors, tensorDataIndex);
    }

    if (PyTuple_Check(obj)) {
        Py_ssize_t size = PyTuple_GET_SIZE(obj);
        PyObject* tuple = PyTuple_New(size);
        if (tuple == nullptr) {
            return nullptr;
        }
        for (Py_ssize_t i = 0; i < size; ++i) {
            PyObject* item = MaterializePendingArg(PyTuple_GET_ITEM(obj, i), payload, copiedTensors, tensorDataIndex);
            if (item == nullptr) {
                Py_DECREF(tuple);
                return nullptr;
            }
            PyTuple_SET_ITEM(tuple, i, item);
        }
        return tuple;
    }

    if (PyList_Check(obj)) {
        Py_ssize_t size = PyList_GET_SIZE(obj);
        PyObject* list = PyList_New(size);
        if (list == nullptr) {
            return nullptr;
        }
        for (Py_ssize_t i = 0; i < size; ++i) {
            PyObject* item = MaterializePendingArg(PyList_GET_ITEM(obj, i), payload, copiedTensors, tensorDataIndex);
            if (item == nullptr) {
                Py_DECREF(list);
                return nullptr;
            }
            PyList_SET_ITEM(list, i, item);
        }
        return list;
    }

    Py_INCREF(obj);
    return obj;
}

// Materializes the payload's user data for one invocation.
// On failure the error is reported as unraisable against the callback and nullptr is returned.
PyObject* MaterializePendingArgs(PendingCallPayload* payload, const std::vector<at::Tensor>& copiedTensors)
{
    size_t tensorDataIndex = 0;
    PyObject* materializedArgs = MaterializePendingArg(
        payload->pyFuncData.pyFuncArgs, payload, copiedTensors, tensorDataIndex);
    if (materializedArgs == nullptr) {
        PyErr_WriteUnraisable(payload->pyFuncData.pyFunc);
        PyErr_Clear();
        return nullptr;
    }
    return materializedArgs;
}

// Pending-call handler, run by the interpreter with the GIL held (per the Py_AddPendingCall
// contract). Drains the queue, invokes each callback as pyFunc(materializedArgs), reports
// callback errors as unraisable, drops the buffer snapshots, and reschedules itself if producers
// enqueued more work while it ran and nobody else scheduled a run.
int PendingCallHandler(void *arg)
{
    std::deque<PendingCallbackEntry> readyCallbacks;
    {
        std::lock_guard<std::mutex> lock(pendingCallbacksMutex);
        pendingCallScheduled.store(false, std::memory_order_release);
        readyCallbacks.swap(pendingCallbacksQueue);
    }

    for (auto& callback : readyCallbacks) {
        auto* payload = callback.first;
        auto& copiedTensors = callback.second;
        if (payload == nullptr) {
            continue;
        }
        PyObject* materializedArgs = MaterializePendingArgs(payload, copiedTensors);
        if (materializedArgs == nullptr) {
            copiedTensors.clear();
            continue;
        }
        PyObject* result = PyObject_CallObject(payload->pyFuncData.pyFunc, materializedArgs);
        if (result != nullptr) {
            Py_XDECREF(result);
        } else {
            PyErr_WriteUnraisable(payload->pyFuncData.pyFunc);
        }
        Py_DECREF(materializedArgs);
        copiedTensors.clear();
    }

    bool shouldScheduleAgain = false;
    {
        std::lock_guard<std::mutex> lock(pendingCallbacksMutex);
        if (!pendingCallbacksQueue.empty() &&
            !pendingCallScheduled.exchange(true, std::memory_order_acq_rel)) {
            shouldScheduleAgain = true;
        }
    }

    if (shouldScheduleAgain && !TrySchedulePendingCall()) {
        StartPendingCallRetryThread();
    }

    return 0;
}

// Host function registered by _launch_host_func_pending through c10_npu::launch_host_func.
// It does not take the GIL. It snapshots the host buffers, enqueues the payload, and schedules
// PendingCallHandler unless a run is already scheduled.
// The payload is not freed here or in PendingCallHandler. NOTE(review): presumably it is reused
// on every graph replay — confirm who owns it.
void LaunchCallbackViaPendingCall(void *userData)
{
    auto* payload = static_cast<PendingCallPayload*>(userData);
    if (payload == nullptr) {
        return;
    }

    std::vector<at::Tensor> copiedTensors;
    CopyPendingTensorData(payload, copiedTensors);

    bool shouldSchedule = false;
    {
        std::lock_guard<std::mutex> lock(pendingCallbacksMutex);
        pendingCallbacksQueue.emplace_back(payload, std::move(copiedTensors));
        if (!pendingCallScheduled.exchange(true, std::memory_order_acq_rel)) {
            shouldSchedule = true;
        }
    }

    if (shouldSchedule && !TrySchedulePendingCall()) {
        StartPendingCallRetryThread();
    }
}

} // namespace

class AclSkOptionHelper {
public:
    std::vector<aclskOption> optionsVec;
    std::deque<std::string> stringPool;
    std::deque<std::vector<char*>> ptrArrayPool;
    void processInitOption(const std::string& key, int value)
    {
        static const std::unordered_map<std::string, std::function<void(aclskOption&, int)>> optionHandlers = {
            {"preload_code", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::PRELOAD_CODE;
                opt.preload.preloadMode = static_cast<uint32_t>(val);
            }},
            {"split_mode", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::SPLIT_MODE;
                opt.splitMode.splitCnt = static_cast<uint32_t>(val);
            }},
            {"stream_fusion", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::STREAM_FUSION;
                opt.streamFusion.streamFusion = static_cast<uint32_t>(val);
            }},
            {"debug_sync_all", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::DEBUG_SYNC_ALL;
                opt.debugSync.debugSyncAll = static_cast<uint32_t>(val);
            }},
            {"constant_codegen", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::CONSTANT_CODEGEN;
                opt.constantCodegen.enableConstant = static_cast<uint32_t>(val);
            }},
            {"auto_op_parallel", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::AUTO_OP_PARALLEL;
                opt.autoOpParallel.enableAutoOpParallel = static_cast<uint32_t>(val);
            }},
            {"debug_op_exec_trace", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::DEBUG_OP_EXEC_TRACE;
                opt.debugOpExecTrace.enableOpExecTrace = static_cast<uint32_t>(val);
            }},
            {"debug_cross_core_sync_check", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::DEBUG_CROSS_CORE_SYNC_CHECK;
                opt.debugCrossCoreSyncCheck.enableCrossCoreSyncCheck = static_cast<uint32_t>(val);
            }},
            {"early_start", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::EARLY_START;
                opt.earlyStart.enableEarlyStart = static_cast<uint32_t>(val);
            }},
            {"debug_per_op_max_core_num", [](aclskOption& opt, int val) {
                opt.optionType = aclskOptionType::DEBUG_PER_OP_MAX_CORE_NUM;
                opt.debugPerOpMaxCoreNum.enableDebugPerOpMaxCoreNum = static_cast<uint32_t>(val);
            }}
        };

        auto it = optionHandlers.find(key);
        if (it != optionHandlers.end()) {
            aclskOption opt = {};
            it->second(opt, value);
            optionsVec.push_back(opt);
        }
    }

    void processStringArrayOption(const std::string& key, const std::vector<std::string>& values)
    {
        if (key == "dcci_disable_on_kernel") {
            processDcciDisableOnKernel(values);
        } else if (key == "dcci_before_kernel_start") {
            processDcciBeforeKernelStart(values);
        } else if (key == "dcci_after_kernel_end") {
            processDcciAfterKernelEnd(values);
        } else if (key == "ubuf_lock_ignore_kernel") {
            processUbufLockIgnoreKernel(values);
        }
    }

    uint32_t getDictIntDefault(const py::dict& d, const std::string& key, uint32_t default_val)
    {
        if (d.contains(key)) {
            py::object item = d.attr("__getitem__")(key);
            if (py::isinstance<py::int_>(item)) {
                return item.cast<uint32_t>();
            }
        }
        return default_val;
    }

    void processDictOption(const std::string& key, const py::dict& dict_value)
    {
        if (key == "aggressive_opt_strategies") {
            aclskOption opt = {};
            opt.optionType = aclskOptionType::AGGRESSIVE_OPT_STRATEGIES;
            opt.aggressiveOpts.eventBreakerBypass = getDictIntDefault(dict_value, "event_breaker_bypass", 0);
            opt.aggressiveOpts.valueBreakerBypass = getDictIntDefault(dict_value, "value_breaker_bypass", 0);
            opt.aggressiveOpts.taskBreakerBypass = getDictIntDefault(dict_value, "task_breaker_bypass", 0);
            optionsVec.push_back(opt);
        }
    }

    void processStringOption(const std::string& key, const std::string& value)
    {
        aclskOption opt = {};
        if (key == "opt_extend") {
            opt.optionType = aclskOptionType::OPT_EXTEND_OPTION;
            stringPool.push_back(value);
            opt.optExtend.value = const_cast<char*>(stringPool.back().c_str());
        } else if (key == "debug_extend") {
            opt.optionType = aclskOptionType::DEBUG_EXTEND_OPTION;
            stringPool.push_back(value);
            opt.debugExtend.value = const_cast<char*>(stringPool.back().c_str());
        } else {
            return;
        }
        optionsVec.push_back(opt);
    }

    aclskOptions getStruct()
    {
        aclskOptions finalOpt = {};
        finalOpt.options = optionsVec.data();
        finalOpt.numOptions = optionsVec.size();
        return finalOpt;
    }

private:
    void processDcciDisableOnKernel(const std::vector<std::string>& values)
    {
        aclskOption opt = {};
        opt.optionType = aclskOptionType::DCCI_DISABLE_ON_KERNEL;
        opt.disableKernelDcci.kernelCnt = static_cast<int>(values.size());
        opt.disableKernelDcci.kernelNames = convertStringArray(values);
        optionsVec.push_back(opt);
    }

    void processDcciBeforeKernelStart(const std::vector<std::string>& values)
    {
        aclskOption opt = {};
        opt.optionType = aclskOptionType::DCCI_BEFORE_KERNEL_START;
        opt.dcciBeforeKernelStart.kernelCnt = static_cast<int>(values.size());
        opt.dcciBeforeKernelStart.kernelNames = convertStringArray(values);
        optionsVec.push_back(opt);
    }

    void processDcciAfterKernelEnd(const std::vector<std::string>& values)
    {
        aclskOption opt = {};
        opt.optionType = aclskOptionType::DCCI_AFTER_KERNEL_END;
        opt.dcciAfterKernelEnd.kernelCnt = static_cast<int>(values.size());
        opt.dcciAfterKernelEnd.kernelNames = convertStringArray(values);
        optionsVec.push_back(opt);
    }

    void processUbufLockIgnoreKernel(const std::vector<std::string>& values)
    {
        aclskOption opt = {};
        opt.optionType = aclskOptionType::UBUF_LOCK_IGNORE_KERNEL;
        opt.ubufLockIgnoreKernel.ubufLockIgnoreKernelCnt = static_cast<int>(values.size());
        opt.ubufLockIgnoreKernel.ubufLockIgnoreKernel = convertStringArray(values);
        optionsVec.push_back(opt);
    }

    char** convertStringArray(const std::vector<std::string>& values)
    {
        std::vector<char*> charPtrs;
        charPtrs.reserve(values.size());
        for (const auto& token : values) {
            stringPool.push_back(token);
            charPtrs.push_back(const_cast<char*>(stringPool.back().c_str()));
        }
        ptrArrayPool.push_back(std::move(charPtrs));
        return ptrArrayPool.back().data();
    }
};

// Registers the NPU graph bindings on `module`:
//   - task-group / task-update helpers and super-kernel scopes;
//   - host-function launch and report-thread subscription;
//   - the graph-mode fused-infer-attention-score update entry point;
//   - the _NPUGraph class.
void TORCH_NPU_API THNPGraph_init(PyObject* module) {
    // Pybind11 patch notes say "py::module_" is more up-to-date syntax,
    // but CI linter and some builds prefer "module".
    auto torch_N_m = py::handle(module).cast<py::module>();

    py::class_<c10_npu::NPUTaskGroupHandle>(torch_N_m, "_NPUTaskGroupHandle")
            .def_readonly("task_group", &c10_npu::NPUTaskGroupHandle::task_group);

    torch_N_m.def("_graph_pool_handle", &c10_npu::graph_pool_handle)
        .def("_graph_task_group_begin", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            c10_npu::graph_task_group_begin(THNPUtils_PyObject_to_NPUStream(stream));
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup begin, stream=%p", static_cast<void*>(stream));
        })
        .def("_graph_task_group_end", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            auto handle = c10_npu::graph_task_group_end(THNPUtils_PyObject_to_NPUStream(stream));
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup end, handle=%p", static_cast<void*>(handle.task_group));
            return handle;
        })
        .def("_graph_task_update_begin", [](py::object py_stream, c10_npu::NPUTaskGroupHandle handle) {
            auto stream = (*py_stream).ptr();
            c10_npu::graph_task_update_begin(THNPUtils_PyObject_to_NPUStream(stream), handle);
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup update begin, handle=%p", static_cast<void*>(handle.task_group));
        })
        .def("_graph_task_update_end", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            c10_npu::graph_task_update_end(THNPUtils_PyObject_to_NPUStream(stream));
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup update end");
        })
        .def("_super_kernel_scope_begin", [](const char* scope_name) {
            NPUGRAPH_LOGD("NPUGRAPH SuperKernel scope begin, name=%s",
                          scope_name ? scope_name : "(null)");
            c10_npu::super_kernel_scope_begin(scope_name);
        })
        .def("_super_kernel_scope_end", [](const char* scope_name) {
            NPUGRAPH_LOGD("NPUGRAPH SuperKernel scope end, name=%s",
                          scope_name ? scope_name : "(null)");
            c10_npu::super_kernel_scope_end(scope_name);
        })
        .def("_launch_host_func", [](py::object py_stream, py::object py_func, py::object py_data) {
            auto func = (*py_func).ptr();
            auto userDataList = (*py_data).ptr();
            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
            // Keep ownership until the callback is registered, so a failed launch does not leak it.
            auto data = std::make_unique<PyFuncStruct>(func, userDataList);
            c10_npu::launch_callback(stream, LaunchCallFunc, data.get());
            // From here on `callbacks` owns the data; _unsubscribe_report frees it.
            callbacks[stream].emplace_back(data.release());
        })
        .def("_launch_host_func_pending", [](py::object py_stream, py::object py_func, py::object py_data) {
            auto func = (*py_func).ptr();
            auto userDataList = (*py_data).ptr();
            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
            auto payload = std::make_unique<PendingCallPayload>(func, userDataList);
            if (!CollectPendingTensorData(userDataList, payload.get())) {
                throw py::error_already_set();
            }
            c10_npu::launch_host_func(stream, LaunchCallbackViaPendingCall, payload.get());
            (void)payload.release();
        })
        .def("_subscribe_report", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            aclrtContext context = aclrtContext();
            NPU_CHECK_ERROR(aclrtGetCurrentContext(&context));
            if ((threadArgs == nullptr) || (threadId == -1)) {
                // process_callback takes ownership of these args and deletes them when it exits.
                auto* args = new ThreadArgs(context, false);
                const int createRet = pthread_create(&threadId, nullptr, process_callback, args);
                if (createRet != 0) {
                    // The thread never started, so nothing else owns `args`.
                    delete args;
                    threadArgs = nullptr;
                    threadId = -1;
                    TORCH_CHECK(false,
                        "Failed to create the NPU graph report thread, pthread_create returned ",
                        createRet, PTA_ERROR(ErrCode::INTERNAL));
                }
                threadArgs = args;
            }
            c10_npu::subscribe_report(threadId, THNPUtils_PyObject_to_NPUStream(stream));
        })
        .def("_unsubscribe_report", [](py::object py_stream) {
            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
            c10_npu::unsubscribe_report(threadId, stream);
            auto it = callbacks.find(stream);
            if (it != callbacks.end()) {
                for (PyFuncStruct* func : it->second) {
                    delete func;
                }
                callbacks.erase(it);
            }
            if (callbacks.empty()) {
                if (threadArgs != nullptr) {
                    // process_callback deletes its ThreadArgs once it observes exitFlag. Forget the
                    // pointer here so a later call never writes to freed memory.
                    threadArgs->exitFlag = true;
                    threadArgs = nullptr;
                }
                threadId = -1;
            }
        })
        .def("_npu_fused_infer_attention_score_out_graph", [](
            py::object py_stream,
            c10_npu::NPUTaskGroupHandle handle,
            py::object py_event,
            py::args args,
            py::kwargs kwargs
            ) -> std::tuple<at::Tensor, at::Tensor> {
                // The result is returned by value. The op's result refers to the local
                // `attention_out` / `softmax_lse` handles, which are destroyed before pybind11
                // converts the return value, so returning references would dangle.

                // 1. Define the argument parser.
                static torch::PythonArgParser parser({
                    "npu_fused_infer_attention_score("
                    "Tensor query, Tensor key, Tensor value, *, "
                    "Tensor? pse_shift=None, "
                    "Tensor? atten_mask=None, "
                    "SymIntArrayRef actual_seq_lengths=None, "
                    "SymIntArrayRef actual_seq_lengths_kv=None, "
                    "Tensor? dequant_scale1=None, "
                    "Tensor? quant_scale1=None, "
                    "Tensor? dequant_scale2=None, "
                    "Tensor? quant_scale2=None, " // 10
                    "Tensor? quant_offset2=None, "
                    "Tensor? antiquant_scale=None, "
                    "Tensor? antiquant_offset=None, "
                    "Tensor? key_antiquant_scale=None, "
                    "Tensor? key_antiquant_offset=None, "
                    "Tensor? value_antiquant_scale=None, "
                    "Tensor? value_antiquant_offset=None, "
                    "Tensor? block_table=None, "
                    "Tensor? query_padding_size=None, "
                    "Tensor? kv_padding_size=None, " // 20
                    "Tensor? key_shared_prefix=None, "
                    "Tensor? value_shared_prefix=None, "
                    "SymIntArrayRef actual_shared_prefix_len=None, "
                    "Tensor? query_rope=None, "
                    "Tensor? key_rope=None, "
                    "Tensor? key_rope_antiquant_scale=None, "
                    "int64_t num_heads=1, "
                    "double scale=1.0, "
                    "int64_t pre_tokens=2147483647, "
                    "int64_t next_tokens=2147483647, " // 30
                    "std::string input_layout=\"BSH\", "
                    "int64_t num_key_value_heads=0, "
                    "int64_t sparse_mode=0, "
                    "int64_t inner_precise=0, "
                    "int64_t block_size=0, "
                    "int64_t antiquant_mode=0, "
                    "int64_t key_antiquant_mode=0, "
                    "int64_t value_antiquant_mode=0, "
                    "bool softmax_lse_flag=False, "
                    "Tensor? workspace=None, " // 40
                    "TensorList out)"
                });

                // 2. Get the raw PyObject* handles.
                PyObject* args_ptr = args.ptr();
                PyObject* kwargs_ptr = kwargs.ptr();

                // 3. Parse the arguments (42 parameters, indices 0..41).
                torch::ParsedArgs<42> parsed;
                torch::PythonArgs py_args = parser.parse(args_ptr, kwargs_ptr, parsed);

                // 4. Required arguments.
                at::Tensor query = py_args.tensor(0);
                at::Tensor key = py_args.tensor(1);
                at::Tensor value = py_args.tensor(2);

                // 5. Optional arguments.
                c10::optional<at::Tensor> pse_shift = py_args.optionalTensor(3);
                c10::optional<at::Tensor> atten_mask = py_args.optionalTensor(4);
                c10::OptionalArray<c10::SymInt> actual_seq_lengths = py_args.symintlistOptional(5);
                c10::OptionalArray<c10::SymInt> actual_seq_lengths_kv = py_args.symintlistOptional(6);
                c10::optional<at::Tensor> dequant_scale1 = py_args.optionalTensor(7);
                c10::optional<at::Tensor> quant_scale1 = py_args.optionalTensor(8);
                c10::optional<at::Tensor> dequant_scale2 = py_args.optionalTensor(9);
                c10::optional<at::Tensor> quant_scale2 = py_args.optionalTensor(10);
                c10::optional<at::Tensor> quant_offset2 = py_args.optionalTensor(11);
                c10::optional<at::Tensor> antiquant_scale = py_args.optionalTensor(12);
                c10::optional<at::Tensor> antiquant_offset = py_args.optionalTensor(13);
                c10::optional<at::Tensor> key_antiquant_scale = py_args.optionalTensor(14);
                c10::optional<at::Tensor> key_antiquant_offset = py_args.optionalTensor(15);
                c10::optional<at::Tensor> value_antiquant_scale = py_args.optionalTensor(16);
                c10::optional<at::Tensor> value_antiquant_offset = py_args.optionalTensor(17);
                c10::optional<at::Tensor> block_table = py_args.optionalTensor(18);
                c10::optional<at::Tensor> query_padding_size = py_args.optionalTensor(19);
                c10::optional<at::Tensor> kv_padding_size = py_args.optionalTensor(20);
                c10::optional<at::Tensor> key_shared_prefix = py_args.optionalTensor(21);
                c10::optional<at::Tensor> value_shared_prefix = py_args.optionalTensor(22);
                c10::OptionalArray<c10::SymInt> actual_shared_prefix_len = py_args.symintlistOptional(23);
                c10::optional<at::Tensor> query_rope = py_args.optionalTensor(24);
                c10::optional<at::Tensor> key_rope = py_args.optionalTensor(25);
                c10::optional<at::Tensor> key_rope_antiquant_scale = py_args.optionalTensor(26);
                int64_t num_heads = py_args.toInt64(27);
                double scale = py_args.toDouble(28);
                int64_t pre_tokens = py_args.toInt64(29);
                int64_t next_tokens = py_args.toInt64(30);
                std::string input_layout = py_args.string(31);
                int64_t num_key_value_heads = py_args.toInt64(32);
                int64_t sparse_mode = py_args.toInt64(33);
                int64_t inner_precise = py_args.toInt64(34);
                int64_t block_size = py_args.toInt64(35);
                int64_t antiquant_mode = py_args.toInt64(36);
                int64_t key_antiquant_mode = py_args.toInt64(37);
                int64_t value_antiquant_mode = py_args.toInt64(38);
                bool softmax_lse_flag = py_args.toBool(39);
                c10::optional<at::Tensor> workspace = py_args.optionalTensor(40);
                std::vector<at::Tensor> out = py_args.tensorlist(41);
                TORCH_CHECK(out.size() == 2,
                    "out must have 2 tensors (attention_out, softmax_lse), but got ",
                    out.size(), PTA_ERROR(ErrCode::PARAM));
                at::Tensor attention_out = out[0];
                at::Tensor softmax_lse = out[1];

                auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
                auto event_ptr = THNPUtils_PyObject_to_NPUEvent((*py_event).ptr());
                pybind11::gil_scoped_release no_gil;

                c10_npu::graph_task_update_begin(stream, handle);

                auto fia_result = op_api::npu_fused_infer_attention_score_out_symint(
                    query, key, value,
                    pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
                    dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2,
                    antiquant_scale, antiquant_offset, key_antiquant_scale, key_antiquant_offset,
                    value_antiquant_scale, value_antiquant_offset, block_table,
                    query_padding_size, kv_padding_size, key_shared_prefix, value_shared_prefix,
                    actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
                    num_heads, scale, pre_tokens, next_tokens, input_layout,
                    num_key_value_heads, sparse_mode, inner_precise, block_size,
                    antiquant_mode, key_antiquant_mode, value_antiquant_mode,
                    softmax_lse_flag, workspace, attention_out, softmax_lse
                );

                c10_npu::graph_task_update_end(stream);
                event_ptr->record(stream);
                // Converting to a tuple of values copies the tensor handles while the locals
                // are still alive.
                return fia_result;
        });

    shared_ptr_class_<c10_npu::NPUGraph>(torch_N_m, "_NPUGraph")
        .def(py::init<>())
        .def(
            "capture_begin",
            [](c10_npu::NPUGraph& self,
               std::optional<c10_npu::MempoolId_t> pool_opt,
               std::string capture_error_mode,
               bool report_shape) {
                // Initialized only to keep the compiler from seeing an uninitialized read;
                // every path below either assigns it or throws.
                aclmdlRICaptureMode capture_mode = aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_GLOBAL;
                c10_npu::MempoolId_t pool = pool_opt.has_value()
                    ? pool_opt.value() : c10_npu::MempoolId_t{0, 0};
                if (capture_error_mode == "global") {
                    capture_mode = aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_GLOBAL;
                } else if (capture_error_mode == "thread_local") {
                    capture_mode = aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL;
                } else if (capture_error_mode == "relaxed") {
                    capture_mode = aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_RELAXED;
                } else {
                    TORCH_CHECK(
                        false,
                        "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
                        capture_error_mode);
                }
                return self.capture_begin(pool, capture_mode, report_shape);
            },
            py::arg("pool"),
            py::arg("capture_error_mode"),
            py::arg("report_shape"),
            py::call_guard<py::gil_scoped_release>())
        .def(
            "capture_end",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::capture_end))
        .def(
            "register_generator_state",
            [](c10_npu::NPUGraph& self, py::handle raw_generator) {
                auto generator = THPGenerator_Unwrap(raw_generator.ptr());
                // We've unwrapped Python object to C++ object,
                // so we could release GIL before calling into C++
                py::gil_scoped_release release;
                return self.register_generator_state(generator);
            },
            py::arg("generator"))
        .def(
            "replay",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::replay))
        .def(
            "reset",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::reset))
        .def(
            "pool",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::pool))
        .def(
            "debug_dump",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::debug_dump),
             py::arg("debug_path"))
        .def(
            "enable_debug_mode",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::enable_debug_mode))
        .def(
            "super_kernel_optimize",
            [](c10_npu::NPUGraph& self,
               py::object optimize_options,
               py::object debug_options) {
                // `helper` owns every string/array the options point to, so it must stay alive
                // until super_kernel_optimize returns.
                AclSkOptionHelper helper;
                if (!optimize_options.is_none()) {
                    auto opts = optimize_options.cast<py::dict>();
                    for (auto item : opts) {
                        std::string key = py::str(item.first);
                        if (py::isinstance<py::dict>(item.second)) {
                            helper.processDictOption(key, item.second.cast<py::dict>());
                        } else if (py::isinstance<py::str>(item.second)) {
                            helper.processStringOption(key, item.second.cast<std::string>());
                        } else if (py::isinstance<py::list>(item.second)) {
                            helper.processStringArrayOption(key, item.second.cast<std::vector<std::string>>());
                        } else if (py::isinstance<py::int_>(item.second)) {
                            helper.processInitOption(key, item.second.cast<int>());
                        }
                    }
                }

                if (!debug_options.is_none()) {
                    auto opts = debug_options.cast<py::dict>();
                    for (auto item : opts) {
                        std::string key = py::str(item.first);
                        if (py::isinstance<py::int_>(item.second)) {
                            helper.processInitOption(key, item.second.cast<int>());
                        } else if (py::isinstance<py::str>(item.second)) {
                            helper.processStringOption(key, item.second.cast<std::string>());
                        }
                    }
                }
                aclskOptions options = helper.getStruct();
                {
                    py::gil_scoped_release release;
                    return self.super_kernel_optimize(&options);
                }
            },
            py::arg("optimize_options"),
            py::arg("debug_options"));
}