// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>

#include "context_ext.h"

#include "kcal/core/mpc_operator_base.h"
#include "kcal/core/operator_factory.h"
#include "kcal/operator/kcal_pir.h"
#include "kcal/operator/kcal_psi.h"
#include "kcal/operator/kcal_psi_ub.h"
#include "kcal/utils/io.h"

namespace brpc {
DECLARE_uint64(max_body_size);
DECLARE_int64(socket_max_unwritten_bytes);
} // namespace brpc

namespace py = pybind11;

namespace kcal {

namespace {

// Divisor applied to the int64_t maximum when computing the value assigned to brpc's
// socket_max_unwritten_bytes flag in create_with_link_config.
constexpr int MAX_UNWRITTEN_DIV = 2;

// Converts the raw integer received from Python into a KCAL_AlgorithmsType.
// The value is cast as-is; no range validation is performed here.
inline KCAL_AlgorithmsType IntToAlgorithmsType(int v)
{
    const auto algorithm = static_cast<KCAL_AlgorithmsType>(v);
    return algorithm;
}

// Converts the raw integer received from Python into a DG_TeeMode.
// The value is cast as-is; no range validation is performed here.
inline DG_TeeMode IntToTeeMode(int v)
{
    const auto teeMode = static_cast<DG_TeeMode>(v);
    return teeMode;
}

// Converts the raw integer received from Python into a DG_DummyMode.
// The value is cast as-is; no range validation is performed here.
inline DG_DummyMode IntToDummyMode(int v)
{
    const auto dummyMode = static_cast<DG_DummyMode>(v);
    return dummyMode;
}

// Converts the raw integer received from Python into a DG_ShareType.
// The value is cast as-is; no range validation is performed here.
inline DG_ShareType IntToShareType(int v)
{
    const auto shareType = static_cast<DG_ShareType>(v);
    return shareType;
}

void FeedKcalInput(const py::list &pyList, io::Input *kcalInput)
{
    if (pyList.empty()) {
        return;
    }
    const auto &itemTemp = pyList[0];
    if (py::isinstance<py::str>(itemTemp)) {
        auto *dgString = new (std::nothrow) DG_String[pyList.size()];
        if (!dgString) {
            throw std::bad_alloc();
        }
        for (size_t i = 0; i < pyList.size(); ++i) {
            if (!PyUnicode_Check(pyList[i].ptr())) {
                throw std::runtime_error("need str");
            }

            Py_ssize_t sz;
            const char *utf8 = PyUnicode_AsUTF8AndSize(pyList[i].ptr(), &sz);
            if (!utf8) {
                throw std::bad_alloc();
            }

            dgString[i].str = new char[sz + 1];
            std::memcpy(dgString[i].str, utf8, sz + 1);
            dgString[i].size = static_cast<int>(sz) + 1;
        }
        DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
        (*internalInput)->data.strings = dgString;
        (*internalInput)->size = pyList.size();
        (*internalInput)->dataType = MPC_STRING;
    } else {
        auto inData = std::make_unique<double[]>(pyList.size());
        for (size_t i = 0; i < pyList.size(); ++i) {
            if (py::isinstance<py::int_>(pyList[i]) || py::isinstance<py::float_>(pyList[i])) {
                inData[i] = pyList[i].cast<double>();
            } else {
                throw std::runtime_error("need number type");
            }
        }
        DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
        (*internalInput)->data.doubleNumbers = inData.release();
        (*internalInput)->size = pyList.size();
        (*internalInput)->dataType = MPC_DOUBLE;
    }
}

void FeedKcalInputFile(const py::list &pyList, io::Input *kcalInput)
{
    if (pyList.empty()) {
        return;
    }
    const auto &itemTemp = pyList[0];
    if (py::isinstance<py::str>(itemTemp)) {
        auto *dgString = new (std::nothrow) DG_String[pyList.size()];
        if (!dgString) {
            throw std::bad_alloc();
        }
        for (size_t i = 0; i < pyList.size(); ++i) {
            if (!PyUnicode_Check(pyList[i].ptr())) {
                throw std::runtime_error("need str");
            }

            Py_ssize_t sz;
            const char *utf8 = PyUnicode_AsUTF8AndSize(pyList[i].ptr(), &sz);
            if (!utf8) {
                throw std::bad_alloc();
            }

            dgString[i].str = new char[sz + 1];
            std::memcpy(dgString[i].str, utf8, sz + 1);
            dgString[i].size = static_cast<int>(sz);
        }
        DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
        (*internalInput)->data.strings = dgString;
        (*internalInput)->size = pyList.size();
        (*internalInput)->dataType = MPC_STRING;
    } else {
        auto inData = std::make_unique<double[]>(pyList.size());
        for (size_t i = 0; i < pyList.size(); ++i) {
            if (py::isinstance<py::int_>(pyList[i]) || py::isinstance<py::float_>(pyList[i])) {
                inData[i] = pyList[i].cast<double>();
            } else {
                throw std::runtime_error("need number type");
            }
        }
        DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
        (*internalInput)->data.doubleNumbers = inData.release();
        (*internalInput)->size = pyList.size();
        (*internalInput)->dataType = MPC_DOUBLE;
    }
}

// Builds the DG_Pair array of `pairList` from two parallel Python lists of str.
//
// Every key/value is copied as NUL-terminated UTF-8 into its own DG_String, whose `size`
// counts the terminator (byte length + 1). The array is assembled locally and attached to
// `pairList` only once every element converted, so a failure leaves `pairList` unchanged and
// leaks nothing.
//
// Throws std::runtime_error if the lists differ in length, an element is not a str, a string
// does not fit the int `size` field, or `pairList` has no underlying list;
// py::error_already_set if Python fails to produce the UTF-8 form; std::bad_alloc on
// allocation failure.
void FeedKcalPairList(const py::list &key, const py::list &value, io::KcalPairList *pairList)
{
    if (key.size() != value.size()) {
        throw std::runtime_error("key value size don't match");
    }
    auto *list = pairList->Get();
    if (!list) {
        throw std::runtime_error("pair list is not initialized");
    }
    const size_t size = key.size();

    // Value-initialised so every key/value pointer starts as nullptr; the cleanup path below
    // depends on that to tell filled entries from untouched ones.
    auto *pairs = new (std::nothrow) DG_Pair[size]();
    if (!pairs) {
        throw std::bad_alloc();
    }

    // Copies one Python str into a newly allocated DG_String.
    const auto makeString = [](PyObject *item) -> DG_String * {
        if (!PyUnicode_Check(item)) {
            throw std::runtime_error("need str");
        }
        Py_ssize_t sz = 0;
        const char *utf8 = PyUnicode_AsUTF8AndSize(item, &sz);
        if (!utf8) {
            // CPython has already set the real error; propagate it instead of masking it.
            throw py::error_already_set();
        }
        if (sz >= std::numeric_limits<int>::max()) {
            throw std::runtime_error("str too long");
        }
        const size_t bytes = static_cast<size_t>(sz) + 1;
        std::unique_ptr<DG_String> holder(new DG_String());
        holder->str = new char[bytes];
        std::memcpy(holder->str, utf8, bytes); // includes the trailing NUL
        holder->size = static_cast<int>(sz) + 1;
        return holder.release();
    };

    // Releases a DG_String created by makeString; accepts nullptr.
    const auto destroyString = [](DG_String *s) noexcept {
        if (s) {
            delete[] s->str;
            delete s;
        }
    };

    try {
        for (size_t i = 0; i < size; ++i) {
            pairs[i].key = makeString(key[i].ptr());
            pairs[i].value = makeString(value[i].ptr());
        }
    } catch (...) {
        for (size_t i = 0; i < size; ++i) {
            destroyString(pairs[i].key);
            destroyString(pairs[i].value);
        }
        delete[] pairs;
        throw;
    }

    // Commit: from here on `pairList` owns the array.
    list->dgPair = pairs;
    list->size = size;
}

// Appends the PSI result held by `kcalOutput` to `pyList`.
//
// `mode` selects which member of the result is read: TEE_OUTPUT_INDEX appends the u64 values,
// TEE_OUTPUT_STRING appends each string's `str`. Any other mode appends nothing.
// If the output holds no result structure or no data array (e.g. the operator failed before
// producing one), nothing is appended instead of dereferencing a null pointer.
void FeedPsiOutput(io::Output &kcalOutput, py::list &pyList, DG_TeeMode mode)
{
    auto *outPtr = kcalOutput.Get();
    if (!outPtr) {
        return;
    }
    const auto count = outPtr->size;

    if (mode == TEE_OUTPUT_INDEX) {
        if (!outPtr->data.u64Numbers) {
            return;
        }
        for (size_t i = 0; i < count; ++i) {
            pyList.append(outPtr->data.u64Numbers[i]);
        }
    } else if (mode == TEE_OUTPUT_STRING) {
        if (!outPtr->data.strings) {
            return;
        }
        for (size_t i = 0; i < count; ++i) {
            pyList.append(outPtr->data.strings[i].str);
        }
    }
}

// Appends the values held by `kcalOutput` to `pyList`, choosing the member to read from the
// output's own `dataType`: MPC_STRING -> each string's `str`, MPC_INT -> u64 values,
// MPC_DOUBLE -> doubles. Any other data type appends nothing.
// If the output holds no result structure or no data array (e.g. the operator failed before
// producing one), nothing is appended instead of dereferencing a null pointer.
void FeedKcalOutput(io::Output &kcalOutput, py::list &pyList)
{
    auto *outPtr = kcalOutput.Get();
    if (!outPtr) {
        return;
    }
    const auto dataType = outPtr->dataType;
    const auto count = outPtr->size;

    if (dataType == MPC_STRING) {
        if (!outPtr->data.strings) {
            return;
        }
        for (size_t i = 0; i < count; ++i) {
            pyList.append(outPtr->data.strings[i].str);
        }
    } else if (dataType == MPC_INT) {
        if (!outPtr->data.u64Numbers) {
            return;
        }
        for (size_t i = 0; i < count; ++i) {
            pyList.append(outPtr->data.u64Numbers[i]);
        }
    } else if (dataType == MPC_DOUBLE) {
        if (!outPtr->data.doubleNumbers) {
            return;
        }
        for (size_t i = 0; i < count; ++i) {
            pyList.append(outPtr->data.doubleNumbers[i]);
        }
    }
}

} // namespace

// Bridges the C++ send/receive callback signatures to Python callables.
//
// Both adapters hand the buffer to Python as a zero-copy memoryview, so the view is only
// meaningful while the callback is running; Python code must not keep it afterwards.
// NOTE(review): both adapters call into the Python API and therefore assume the calling thread
// holds the GIL — confirm KCAL never invokes them from its own worker threads.
class PyCallbackAdapter {
public:
    // Invokes `pySendFunc({"nodeId": ...}, memoryview)` and returns its result as int.
    // Returns 0 without calling Python when `data` is null, and -1 (after printing the error)
    // if the callback raises or its result cannot be converted to int.
    static int PySendCallback(const TeeNodeInfo &nodeInfo, const uint8_t *data, size_t dataLen,
                              const py::function &pySendFunc)
    {
        if (!data) {
            return 0;
        }
        try {
            py::dict nodeInfoDict;
            nodeInfoDict["nodeId"] = nodeInfo.nodeId;
            // zero-copy; the source buffer is const, so expose it read-only to keep the Python
            // callback from writing through the view
            py::memoryview dataMview = py::memoryview::from_buffer(
                const_cast<uint8_t *>(data), {static_cast<py::ssize_t>(dataLen)}, {sizeof(uint8_t)}, true);

            py::object result = pySendFunc(nodeInfoDict, dataMview);
            return result.cast<int>();
        } catch (const py::error_already_set &e) {
            py::print("Python send callback error:", e.what());
            return -1;
        } catch (const std::exception &e) {
            py::print("Send callback error:", e.what());
            return -1;
        }
    }

    // Invokes `pyRecvFunc({"nodeId": ...}, memoryview)` with a writable view over `buffer`
    // (`maxLen` bytes) and returns its result as int.
    // Returns 0 without calling Python when `buffer` is null, and -1 if the callback returns
    // None, raises, or returns something that cannot be converted to int.
    static int PyRecvCallback(const TeeNodeInfo &nodeInfo, uint8_t *buffer, size_t maxLen,
                              const py::function &pyRecvFunc)
    {
        if (!buffer) {
            return 0;
        }
        try {
            py::dict nodeInfoDict;
            nodeInfoDict["nodeId"] = nodeInfo.nodeId;
            // zero-copy; writable so the callback can fill the buffer in place
            py::memoryview bufferMview =
                py::memoryview::from_buffer(buffer, {static_cast<py::ssize_t>(maxLen)}, {sizeof(uint8_t)}, false);

            py::object result = pyRecvFunc(nodeInfoDict, bufferMview);
            if (result.is_none()) {
                return -1;
            }
            return result.cast<int>();
        } catch (const py::error_already_set &e) {
            py::print("Python recv callback error:", e.what());
            return -1;
        } catch (const std::exception &e) {
            py::print("Recv callback error:", e.what());
            return -1;
        }
    }
};

// Registers the I/O wrapper classes on module `m`: MpcShare, MpcShareSet and Input.
// `Output` is exposed as a second name for the very same Python class as `Input`.
void BindIoClasses(py::module_ &m)
{
    py::class_<io::MpcShare>(m, "MpcShare")
        .def(py::init<>())
        .def("size", &io::MpcShare::Size)
        .def("type", &io::MpcShare::Type);

    py::class_<io::MpcShareSet>(m, "MpcShareSet")
        .def(py::init<>())
        .def_static(
            "Create", [](const std::vector<io::MpcShare *> &shares) { return io::MpcShareSet::Create(shares); },
            py::return_value_policy::take_ownership)
        .def(
            "Get", [](io::MpcShareSet &self) -> DG_MpcShareSet * { return self.Get(); },
            py::return_value_policy::reference);

    py::class_<io::Input>(m, "Input")
        .def(py::init<>())
        .def(py::init<DG_TeeInput *>())
        // Factory taking no `self`: it must be static, otherwise calling it on an instance
        // fails because the instance is passed as an unexpected argument. Calling it on the
        // class (`Input.create()`) keeps working as before.
        .def_static(
            "create",
            [] {
                auto teeInput = std::make_unique<DG_TeeInput>();
                auto input = std::make_unique<io::Input>(teeInput.release());
                return input;
            },
            py::return_value_policy::take_ownership)
        .def("Set", &io::Input::Set)
        .def("Fill", &io::Input::Fill)
        .def("Size", &io::Input::Size);

    // Alias of Input
    m.attr("Output") = m.attr("Input");
}

// Registers the non-MPC operators on module `m`: Psi (private set intersection) and
// Pir (private information retrieval). Each list-based method converts the Python list into a
// KCAL input, runs the operator and appends the result to the caller-supplied output list;
// the operator's integer return code is passed through unchanged.
void BindOtherOperators(py::module_ &m)
{
    // ---- PSI ----
    py::class_<Psi, std::shared_ptr<Psi>> psiClass(m, "Psi");
    psiClass.def(py::init<std::shared_ptr<Context>>());
    psiClass.def("run", [](Psi &psi, const py::list &pyIn, py::list &pyOut, int rawMode) -> int {
        DG_TeeMode teeMode = IntToTeeMode(rawMode);
        io::Input request(new DG_TeeInput());
        FeedKcalInput(pyIn, &request);
        io::Output response;
        const int status = psi.Run(request, response, teeMode);
        // The output list is fed regardless of the status code.
        FeedPsiOutput(response, pyOut, teeMode);
        return status;
    });

    // ---- PIR ----
    // Registration order matters for overload resolution, so it is kept as-is.
    py::class_<Pir, std::shared_ptr<Pir>> pirClass(m, "Pir");
    pirClass.def(py::init<std::shared_ptr<Context>>());

    // In-memory server preprocessing from parallel key / value lists.
    pirClass.def("ServerPreProcess", [](Pir &pir, const py::list &keys, py::list &values) -> int {
        // build DG_PairList
        std::unique_ptr<io::KcalPairList> pairs(io::KcalPairList::Create());
        FeedKcalPairList(keys, values, pairs.get());
        return pir.ServerPreProcess(pairs->Get());
    });

    // Query whose string sizes include the trailing NUL (see FeedKcalInput).
    pirClass.def("ClientQuery", [](Pir &pir, const py::list &pyIn, py::list &pyOut, int rawMode) -> int {
        DG_DummyMode dummyMode = IntToDummyMode(rawMode);
        io::Input request(new DG_TeeInput());
        FeedKcalInput(pyIn, &request);
        io::Output response;
        const int status = pir.ClientQuery(request, response, dummyMode);
        FeedKcalOutput(response, pyOut);
        return status;
    });

    // Same query path, but the input is prepared with FeedKcalInputFile.
    pirClass.def("ClientQueryByFile", [](Pir &pir, const py::list &pyIn, py::list &pyOut, int rawMode) -> int {
        DG_DummyMode dummyMode = IntToDummyMode(rawMode);
        io::Input request(new DG_TeeInput());
        FeedKcalInputFile(pyIn, &request);
        io::Output response;
        const int status = pir.ClientQuery(request, response, dummyMode);
        FeedKcalOutput(response, pyOut);
        return status;
    });

    pirClass.def("ServerAnswer", [](Pir &pir) -> int { return pir.ServerAnswer(); });

    // File-based overloads.
    pirClass.def(
        "ServerPreProcess",
        [](Pir &pir, const std::string &inputFile, const std::string &outputDir) -> int {
            return pir.ServerPreProcess(inputFile, outputDir);
        },
        py::arg("input"), py::arg("outputPath"));

    pirClass.def(
        "ServerAnswer",
        [](Pir &pir, const std::string &cachePath, int deleteCacheFlag) -> int {
            return pir.ServerAnswer(cachePath, deleteCacheFlag);
        },
        py::arg("dataPath"), py::arg("isDeleteCache"));
}

// Registers the MPC operator classes on module `m`: the generic OperatorBase plus MakeShare
// and RevealShare. Each class offers an in-memory `run` overload and a file-based `run`
// overload; the file-based ones return a `(return_code, count)` tuple.
void BindMpcOperators(py::module_ &m)
{
    // ---- Generic MPC operator ----
    py::class_<MpcOperatorBase, std::shared_ptr<MpcOperatorBase>> baseClass(m, "OperatorBase");
    baseClass.def("GetType", &MpcOperatorBase::GetType);
    baseClass.def("run",
                  [](MpcOperatorBase &op, const std::vector<io::MpcShare *> &inShares, io::MpcShare *result) -> int {
                      // NOTE(review): if MpcShareSet::Create hands back an owning raw pointer,
                      // nothing here releases it — confirm the ownership contract.
                      auto shareSet = io::MpcShareSet::Create(inShares);
                      return op.Run(shareSet, result);
                  });
    baseClass.def(
        "run",
        [](MpcOperatorBase &op, const std::vector<std::string> &inPaths, std::string &outPath) -> py::tuple {
            int produced = 0;
            const int status = op.Run(inPaths, outPath, produced);
            return py::make_tuple(status, produced);
        },
        py::arg("inputFilePaths"), py::arg("outputFilePath"));

    // ---- MakeShare ----
    py::class_<MakeShare, std::shared_ptr<MakeShare>> makeShareClass(m, "MakeShare");
    makeShareClass.def(py::init<std::shared_ptr<Context>>());
    makeShareClass.def("run", [](MakeShare &op, const py::list &pyIn, int recvFlag, io::MpcShare &outShare) {
        io::Input request(new DG_TeeInput());
        FeedKcalInput(pyIn, &request);
        return op.Run(request, recvFlag, &outShare);
    });
    makeShareClass.def(
        "run",
        [](MakeShare &op, const std::string &inPath, int recvFlag, const std::string &sharePath) -> py::tuple {
            int produced = 0;
            const int status = op.Run(inPath, recvFlag, sharePath, produced);
            return py::make_tuple(status, produced);
        },
        py::arg("inputFilePath"), py::arg("isRecvShare"), py::arg("shareFilePath"));

    // ---- RevealShare ----
    py::class_<RevealShare, std::shared_ptr<RevealShare>> revealClass(m, "RevealShare");
    revealClass.def(py::init<std::shared_ptr<Context>>());
    revealClass.def("run", [](RevealShare &op, const io::MpcShare &inShare, py::list &pyOut) {
        io::Output revealed;
        int status = op.Run(&inShare, revealed);
        FeedKcalOutput(revealed, pyOut);
        return status;
    });
    revealClass.def(
        "run",
        [](RevealShare &op, const std::string &sharePath, const std::string &outPath) -> py::tuple {
            int produced = 0;
            const int status = op.Run(sharePath, outPath, produced);
            return py::make_tuple(status, produced);
        },
        py::arg("shareFilePath"), py::arg("outputFilePath"));
}

// Registers the `Context` class (backed by ContextExt) on module `m`, together with its three
// static factories and the accessor for the underlying yacl context.
void BindMpcContext(py::module_ &m)
{
    py::class_<ContextExt, std::shared_ptr<ContextExt>> contextClass(m, "Context");
    contextClass.def(py::init<>());

    // Context driven by user-supplied Python send / receive callables.
    contextClass.def_static("create", [](Config config, py::function sendCb, py::function recvCb) {
        auto sendAdapter = [sendCb](const TeeNodeInfo &peer, const uint8_t *payload, size_t payloadLen) {
            return PyCallbackAdapter::PySendCallback(peer, payload, payloadLen, sendCb);
        };
        auto recvAdapter = [recvCb](const TeeNodeInfo &peer, uint8_t *dest, size_t capacity) {
            return PyCallbackAdapter::PyRecvCallback(peer, dest, capacity, recvCb);
        };
        return ContextExt::Create(config, sendAdapter, recvAdapter);
    });

    // Context on top of a yacl link; the mesh is NOT connected automatically.
    contextClass.def_static("create_with_yacl", [](Config config, const yacl::link::ContextDesc &desc, size_t rank) {
        auto linker = YaclLinker::Create(config, desc, rank);
        if (!linker) {
            throw std::runtime_error("Failed to create YaclLinker");
        }
        return ContextExt::CreateFromYaclLinker(config, linker);
    });

    // Context on top of a yacl link; connects the mesh before returning.
    contextClass.def_static(
        "create_with_link_config",
        [](Config config, const yacl::link::ContextDesc &desc, size_t rank, bool log_details = false) {
            auto linker = YaclLinker::Create(config, desc, rank);
            if (!linker) {
                throw std::runtime_error("Failed to create YaclLinker");
            }
            // Fetch the yacl context so ConnectToMesh can establish the network links.
            auto yaclCtx = linker->GetYaclContext();
            if (!yaclCtx) {
                throw std::runtime_error("Failed to get yacl context");
            }
            {
                // The GIL is released for the duration of the connection attempt.
                py::gil_scoped_release unlockGil;
                std::cout << "[create_with_link_config] Calling ConnectToMesh..." << std::endl;
                // Raise the brpc limits before connecting (modelled on libspu.cc).
                brpc::FLAGS_max_body_size = std::numeric_limits<uint64_t>::max();
                brpc::FLAGS_socket_max_unwritten_bytes = std::numeric_limits<int64_t>::max() / MAX_UNWRITTEN_DIV;
                // log_details picks the log level handed to ConnectToMesh.
                yaclCtx->ConnectToMesh(log_details ? spdlog::level::info : spdlog::level::debug);
                std::cout << "[create_with_link_config] ConnectToMesh done" << std::endl;
            }
            // Wrap the connected linker in a ContextExt.
            return ContextExt::CreateFromYaclLinker(config, linker);
        },
        py::arg("config"), py::arg("desc"), py::arg("rank"), py::arg("log_details") = false);

    // Accessor for the underlying yacl context.
    contextClass.def(
        "get_yacl_context", [](const std::shared_ptr<ContextExt> &ctx) { return ctx->GetYaclContext(); },
        py::return_value_policy::reference);
}

// Registers the operator factory functions (`create_psi`, `create_pir`, `create_make_share`,
// `create_reveal_share`, `create_psi_ub`, `create_mpc`) and the PsiUb class on module `m`.
// It also calls BindOtherOperators to register Psi and Pir.
//
// Every factory dereferences its `context` argument, and pybind11 would otherwise accept
// Python `None` for a shared_ptr parameter and pass an empty pointer. The argument is
// therefore declared `.none(false)`, so `None` is rejected with a TypeError instead of
// crashing the interpreter. As a side effect the parameter is now also usable by keyword.
void BindMpcCreateInstance(py::module_ &m)
{
    m.def(
        "create_psi",
        [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<Psi> {
            return OperatorFactory::CreatePsi(context->GetKcalContext());
        },
        py::arg("context").none(false));

    m.def(
        "create_pir",
        [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<Pir> {
            return OperatorFactory::CreatePir(context->GetKcalContext());
        },
        py::arg("context").none(false));

    BindOtherOperators(m);

    m.def(
        "create_make_share",
        [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<MakeShare> {
            return OperatorFactory::CreateMakeShare(context->GetKcalContext());
        },
        py::arg("context").none(false));

    m.def(
        "create_reveal_share",
        [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<RevealShare> {
            return OperatorFactory::CreateRevealShare(context->GetKcalContext());
        },
        py::arg("context").none(false));

    // File-based PSI variant; `run` returns a `(return_code, output_count)` tuple.
    py::class_<PsiUb, std::shared_ptr<PsiUb>>(m, "PsiUb")
        .def(py::init<std::shared_ptr<Context>>())
        .def("run", [](PsiUb &self, const std::string &inputFilePath, const std::string &outputFilePath) -> py::tuple {
            int outputCount = 0;
            int ret = self.Run(inputFilePath, outputFilePath, outputCount);
            return py::make_tuple(ret, outputCount);
        }, py::arg("inputFilePath"), py::arg("outputFilePath"));

    m.def(
        "create_psi_ub",
        [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<PsiUb> {
            return OperatorFactory::CreatePsiUb(context->GetKcalContext());
        },
        py::arg("context").none(false));

    // `type` is the integer value of a KCAL_AlgorithmsType; it is cast without validation.
    m.def(
        "create_mpc",
        [](const std::shared_ptr<ContextExt> &context, int type) -> std::shared_ptr<MpcOperatorBase> {
            KCAL_AlgorithmsType algType = IntToAlgorithmsType(type);
            return OperatorFactory::CreateMpc(context->GetKcalContext(), algType);
        },
        py::arg("context").none(false), py::arg("type"));
}

// Entry point of the `libkcal` extension module: exposes Config, the yacl link description
// types (Party, LinkDesc) and then delegates to the Bind* helpers for contexts, I/O classes
// and operators.
PYBIND11_MODULE(libkcal, m)
{
    using LinkDescType = yacl::link::ContextDesc;
    using PartyType = LinkDescType::Party;

    m.doc() = "KCAL Python bindings.";

    // ---- Config: plain read/write settings ----
    py::class_<Config> configClass(m, "Config");
    configClass.def(py::init<>());
    configClass.def_readwrite("nodeId", &Config::nodeId);
    configClass.def_readwrite("fixBits", &Config::fixBits);
    configClass.def_readwrite("threadCount", &Config::threadCount);
    configClass.def_readwrite("worldSize", &Config::worldSize);
    configClass.def_readwrite("useSMAlg", &Config::useSMAlg);
    configClass.def_readwrite("chunkSize", &Config::chunkSize);
    configClass.def_readwrite("bucketCount", &Config::bucketCount);
    configClass.def_readwrite("tmpPath", &Config::tmpPath);

    // ---- Party: read-only (id, host) pair ----
    py::class_<PartyType> partyClass(m, "Party");
    partyClass.def(py::init<>());
    partyClass.def(py::init<const std::string &, const std::string &>());
    partyClass.def_readonly("id", &PartyType::id);
    partyClass.def_readonly("host", &PartyType::host);

    // ---- LinkDesc: yacl link description ----
    py::class_<LinkDescType> linkDescClass(m, "LinkDesc");
    linkDescClass.def(py::init<>());
    linkDescClass.def_readwrite("id", &LinkDescType::id);
    // `parties` is read-only from Python; entries are added through add_party.
    linkDescClass.def_readonly("parties", &LinkDescType::parties);
    linkDescClass.def_readwrite("connect_retry_times", &LinkDescType::connect_retry_times);
    linkDescClass.def_readwrite("connect_retry_interval_ms", &LinkDescType::connect_retry_interval_ms);
    linkDescClass.def_readwrite("recv_timeout_ms", &LinkDescType::recv_timeout_ms);
    linkDescClass.def_readwrite("http_max_payload_size", &LinkDescType::http_max_payload_size);
    linkDescClass.def_readwrite("http_timeout_ms", &LinkDescType::http_timeout_ms);
    linkDescClass.def_readwrite("throttle_window_size", &LinkDescType::throttle_window_size);
    linkDescClass.def_readwrite("link_type", &LinkDescType::link_type);
    linkDescClass.def(
        "add_party",
        [](LinkDescType &desc, std::string id, std::string host) {
            desc.parties.push_back({std::move(id), std::move(host)});
        },
        "add a party to the link");

    // ---- Remaining bindings, in the original registration order ----
    BindMpcContext(m);

    BindIoClasses(m);

    BindMpcOperators(m);

    BindMpcCreateInstance(m);
}

} // namespace kcal