// pybind11 (and therefore Python.h) must be included before any standard header.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>
#include "context_ext.h"
#include "kcal/core/mpc_operator_base.h"
#include "kcal/core/operator_factory.h"
#include "kcal/operator/kcal_pir.h"
#include "kcal/operator/kcal_psi.h"
#include "kcal/operator/kcal_psi_ub.h"
#include "kcal/utils/io.h"
namespace brpc {
DECLARE_uint64(max_body_size);
DECLARE_int64(socket_max_unwritten_bytes);
}
namespace py = pybind11;
namespace kcal {
namespace {
// Divisor applied to INT64_MAX when raising brpc's socket_max_unwritten_bytes
// limit in Context.create_with_link_config.
constexpr int MAX_UNWRITTEN_DIV = 2;

// int -> enum converters for values arriving from Python. Each is a plain
// static_cast with no range validation; callers are trusted to pass a value
// that belongs to the corresponding enum.
inline auto IntToAlgorithmsType(int v) -> KCAL_AlgorithmsType
{
    return static_cast<KCAL_AlgorithmsType>(v);
}

inline auto IntToTeeMode(int v) -> DG_TeeMode
{
    return static_cast<DG_TeeMode>(v);
}

inline auto IntToDummyMode(int v) -> DG_DummyMode
{
    return static_cast<DG_DummyMode>(v);
}

inline auto IntToShareType(int v) -> DG_ShareType
{
    return static_cast<DG_ShareType>(v);
}
void FeedKcalInput(const py::list &pyList, io::Input *kcalInput)
{
if (pyList.empty()) {
return;
}
const auto &itemTemp = pyList[0];
if (py::isinstance<py::str>(itemTemp)) {
auto *dgString = new (std::nothrow) DG_String[pyList.size()];
if (!dgString) {
throw std::bad_alloc();
}
for (size_t i = 0; i < pyList.size(); ++i) {
if (!PyUnicode_Check(pyList[i].ptr())) {
throw std::runtime_error("need str");
}
Py_ssize_t sz;
const char *utf8 = PyUnicode_AsUTF8AndSize(pyList[i].ptr(), &sz);
if (!utf8) {
throw std::bad_alloc();
}
dgString[i].str = new char[sz + 1];
std::memcpy(dgString[i].str, utf8, sz + 1);
dgString[i].size = static_cast<int>(sz) + 1;
}
DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
(*internalInput)->data.strings = dgString;
(*internalInput)->size = pyList.size();
(*internalInput)->dataType = MPC_STRING;
} else {
auto inData = std::make_unique<double[]>(pyList.size());
for (size_t i = 0; i < pyList.size(); ++i) {
if (py::isinstance<py::int_>(pyList[i]) || py::isinstance<py::float_>(pyList[i])) {
inData[i] = pyList[i].cast<double>();
} else {
throw std::runtime_error("need number type");
}
}
DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
(*internalInput)->data.doubleNumbers = inData.release();
(*internalInput)->size = pyList.size();
(*internalInput)->dataType = MPC_DOUBLE;
}
}
void FeedKcalInputFile(const py::list &pyList, io::Input *kcalInput)
{
if (pyList.empty()) {
return;
}
const auto &itemTemp = pyList[0];
if (py::isinstance<py::str>(itemTemp)) {
auto *dgString = new (std::nothrow) DG_String[pyList.size()];
if (!dgString) {
throw std::bad_alloc();
}
for (size_t i = 0; i < pyList.size(); ++i) {
if (!PyUnicode_Check(pyList[i].ptr())) {
throw std::runtime_error("need str");
}
Py_ssize_t sz;
const char *utf8 = PyUnicode_AsUTF8AndSize(pyList[i].ptr(), &sz);
if (!utf8) {
throw std::bad_alloc();
}
dgString[i].str = new char[sz + 1];
std::memcpy(dgString[i].str, utf8, sz + 1);
dgString[i].size = static_cast<int>(sz);
}
DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
(*internalInput)->data.strings = dgString;
(*internalInput)->size = pyList.size();
(*internalInput)->dataType = MPC_STRING;
} else {
auto inData = std::make_unique<double[]>(pyList.size());
for (size_t i = 0; i < pyList.size(); ++i) {
if (py::isinstance<py::int_>(pyList[i]) || py::isinstance<py::float_>(pyList[i])) {
inData[i] = pyList[i].cast<double>();
} else {
throw std::runtime_error("need number type");
}
}
DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer();
(*internalInput)->data.doubleNumbers = inData.release();
(*internalInput)->size = pyList.size();
(*internalInput)->dataType = MPC_DOUBLE;
}
}
// Fills `pairList` with key/value pairs taken element-wise from two Python
// lists of `str`. Each string is copied as NUL-terminated UTF-8, and
// DG_String::size includes the NUL.
// Throws:
//   std::runtime_error("key value size don't match") if the lists differ in length,
//   std::runtime_error("need str") for a non-str element,
//   py::error_already_set if CPython cannot produce UTF-8,
//   std::runtime_error("str too long") if a length does not fit the int size field,
//   std::bad_alloc on allocation failure.
// Cleanup contract: the `size` field is raised to cover entry i *before* any
// operation on entry i can throw. This is the same convention the original
// bad_alloc path relied on, presumably so that KcalPairList's destructor
// frees entries [0, size). NOTE(review): the destructor is not visible here;
// confirm it frees key/value and their `str` buffers for entries < size.
void FeedKcalPairList(const py::list &key, const py::list &value, io::KcalPairList *pairList)
{
    if (key.size() != value.size()) {
        throw std::runtime_error("key value size don't match");
    }

    // Copies one Python str into an already-allocated DG_String.
    auto copyPyStr = [](const py::object &obj, DG_String *dst) {
        if (!PyUnicode_Check(obj.ptr())) {
            throw std::runtime_error("need str");
        }
        Py_ssize_t sz = 0;
        const char *utf8 = PyUnicode_AsUTF8AndSize(obj.ptr(), &sz);
        if (utf8 == nullptr) {
            // Propagate CPython's real error (e.g. UnicodeEncodeError) instead of a MemoryError.
            throw py::error_already_set();
        }
        if (sz > std::numeric_limits<int>::max() - 1) {
            throw std::runtime_error("str too long");
        }
        dst->str = new char[sz + 1];
        std::memcpy(dst->str, utf8, static_cast<size_t>(sz) + 1);
        dst->size = static_cast<int>(sz) + 1;
    };

    const size_t size = key.size();
    auto *dgList = pairList->Get();
    // Value-initialised so every key/value pointer starts out null.
    dgList->dgPair = new (std::nothrow) DG_Pair[size]();
    if (!dgList->dgPair) {
        throw std::bad_alloc();
    }
    for (size_t i = 0; i < size; ++i) {
        DG_Pair &pair = dgList->dgPair[i];
        pair.key = new (std::nothrow) DG_String();
        pair.value = new (std::nothrow) DG_String();
        // Publish entry i before anything below can throw, so it is covered by the cleanup contract.
        dgList->size = i + 1;
        if (!pair.key || !pair.value) {
            throw std::bad_alloc();
        }
        copyPyStr(key[i], pair.key);
        copyPyStr(value[i], pair.value);
    }
    dgList->size = size;
}
// Appends the PSI result held in `kcalOutput` to `pyList`, interpreting it
// according to the requested tee mode:
//   TEE_OUTPUT_INDEX  -> one int per element (from data.u64Numbers)
//   TEE_OUTPUT_STRING -> one str per element (from data.strings[i].str)
// Any other mode appends nothing.
// The caller default-constructs the Output and feeds it whatever Run's status,
// so a null underlying DG_TeeInput is treated as an empty result instead of
// being dereferenced. NOTE(review): confirm when Get() can be null.
void FeedPsiOutput(io::Output &kcalOutput, py::list &pyList, DG_TeeMode mode)
{
    const auto *outPtr = kcalOutput.Get();
    if (outPtr == nullptr) {
        return;
    }
    // The mode does not change per element, so dispatch once outside the loop.
    if (mode == TEE_OUTPUT_INDEX) {
        for (size_t i = 0; i < outPtr->size; ++i) {
            pyList.append(outPtr->data.u64Numbers[i]);
        }
    } else if (mode == TEE_OUTPUT_STRING) {
        for (size_t i = 0; i < outPtr->size; ++i) {
            pyList.append(outPtr->data.strings[i].str);
        }
    }
}
// Appends every element of `kcalOutput` to `pyList`, converting according to
// the output's own dataType:
//   MPC_STRING -> str   (data.strings[i].str)
//   MPC_INT    -> int   (data.u64Numbers[i])
//   MPC_DOUBLE -> float (data.doubleNumbers[i])
// Any other data type appends nothing.
// Callers feed the Output regardless of the operator's return code, so a null
// underlying DG_TeeInput is treated as an empty result rather than
// dereferenced. NOTE(review): confirm when Get() can be null.
void FeedKcalOutput(io::Output &kcalOutput, py::list &pyList)
{
    const auto *outPtr = kcalOutput.Get();
    if (outPtr == nullptr) {
        return;
    }
    const auto dataType = outPtr->dataType;
    for (size_t i = 0; i < outPtr->size; ++i) {
        if (dataType == MPC_STRING) {
            pyList.append(outPtr->data.strings[i].str);
        } else if (dataType == MPC_INT) {
            pyList.append(outPtr->data.u64Numbers[i]);
        } else if (dataType == MPC_DOUBLE) {
            pyList.append(outPtr->data.doubleNumbers[i]);
        }
    }
}
}
// Bridges KCAL's C++ transport callbacks to Python callables.
//
// Both adapters acquire the GIL themselves before touching any Python object.
// KCAL may invoke them from native threads that do not hold it, and
// py::gil_scoped_acquire is a no-op when the calling thread already holds the
// GIL. Exceptions raised while calling the Python function (Python or C++) are
// printed and reported to KCAL as -1.
class PyCallbackAdapter {
public:
    // Calls pySendFunc({"nodeId": nodeInfo.nodeId}, memoryview(data[0:dataLen])).
    // The memoryview is read-only and aliases `data` without copying, so the
    // Python function must not keep it (or anything exported from it) after it
    // returns.
    // Returns 0 without calling Python if `data` is null. Otherwise returns the
    // callable's result converted to int, or -1 on any error (including a result
    // that cannot be converted to int).
    static int PySendCallback(const TeeNodeInfo &nodeInfo, const uint8_t *data, size_t dataLen,
                              const py::function &pySendFunc)
    {
        if (!data) {
            return 0;
        }
        py::gil_scoped_acquire gil;
        try {
            py::dict nodeInfoDict;
            nodeInfoDict["nodeId"] = nodeInfo.nodeId;
            // readonly = true: `data` is const, so Python must not be able to write through the view.
            py::memoryview dataMview = py::memoryview::from_buffer(
                const_cast<uint8_t *>(data), {static_cast<py::ssize_t>(dataLen)}, {sizeof(uint8_t)}, true);
            py::object result = pySendFunc(nodeInfoDict, dataMview);
            return result.cast<int>();
        } catch (const py::error_already_set &e) {
            py::print("Python send callback error:", e.what());
            return -1;
        } catch (const std::exception &e) {
            py::print("Send callback error:", e.what());
            return -1;
        }
    }

    // Calls pyRecvFunc({"nodeId": nodeInfo.nodeId}, memoryview(buffer[0:maxLen])).
    // The memoryview is writable and aliases `buffer`, so the Python function
    // fills it in place and must not keep it after returning.
    // Returns 0 without calling Python if `buffer` is null, and -1 if the
    // callable returns None or raises. Otherwise returns its result converted
    // to int (presumably the number of bytes written; TODO confirm with KCAL's
    // recv contract).
    static int PyRecvCallback(const TeeNodeInfo &nodeInfo, uint8_t *buffer, size_t maxLen,
                              const py::function &pyRecvFunc)
    {
        if (!buffer) {
            return 0;
        }
        py::gil_scoped_acquire gil;
        try {
            py::dict nodeInfoDict;
            nodeInfoDict["nodeId"] = nodeInfo.nodeId;
            py::memoryview bufferMview =
                py::memoryview::from_buffer(buffer, {static_cast<py::ssize_t>(maxLen)}, {sizeof(uint8_t)}, false);
            py::object result = pyRecvFunc(nodeInfoDict, bufferMview);
            if (result.is_none()) {
                return -1;
            }
            return result.cast<int>();
        } catch (const py::error_already_set &e) {
            py::print("Python recv callback error:", e.what());
            return -1;
        } catch (const std::exception &e) {
            py::print("Recv callback error:", e.what());
            return -1;
        }
    }
};
// Registers the I/O wrapper classes: MpcShare, MpcShareSet and Input.
// `Output` is exposed as an alias of the same Python class as `Input`.
void BindIoClasses(py::module_ &m)
{
    py::class_<io::MpcShare>(m, "MpcShare")
        .def(py::init<>())
        .def("size", &io::MpcShare::Size)
        .def("type", &io::MpcShare::Type);
    py::class_<io::MpcShareSet>(m, "MpcShareSet")
        .def(py::init<>())
        .def_static(
            "Create", [](const std::vector<io::MpcShare *> &shares) { return io::MpcShareSet::Create(shares); },
            py::return_value_policy::take_ownership)
        .def(
            "Get", [](io::MpcShareSet &self) -> DG_MpcShareSet * { return self.Get(); },
            py::return_value_policy::reference);
    py::class_<io::Input>(m, "Input")
        .def(py::init<>())
        .def(py::init<DG_TeeInput *>())
        // Factory returning an Input that wraps a fresh, value-initialised DG_TeeInput.
        // Bound as a static method: the lambda takes no `self`, so as an instance method
        // it could only be called through the class. `Input.create()` keeps working and
        // `Input().create()` now works too.
        .def_static(
            "create",
            [] {
                auto teeInput = std::make_unique<DG_TeeInput>();
                auto input = std::make_unique<io::Input>(teeInput.get());
                // Release only after io::Input has been built and took the pointer, so
                // a bad_alloc while allocating the Input no longer leaks the DG_TeeInput.
                teeInput.release();
                return input;
            },
            py::return_value_policy::take_ownership)
        .def("Set", &io::Input::Set)
        .def("Fill", &io::Input::Fill)
        .def("Size", &io::Input::Size);
    m.attr("Output") = m.attr("Input");
}
// Registers the Psi and Pir operator classes.
// pybind11 tries same-named overloads in registration order, so the relative
// order of these definitions must be kept:
//   ServerPreProcess(key: list, value: list), then ServerPreProcess(input: str, outputPath: str)
//   ServerAnswer(), then ServerAnswer(dataPath: str, isDeleteCache: int)
// Each list-based entry point fills its `output` list in place and returns the
// operator's status code.
void BindOtherOperators(py::module_ &m)
{
    py::class_<Psi, std::shared_ptr<Psi>> psi(m, "Psi");
    psi.def(py::init<std::shared_ptr<Context>>());
    psi.def("run", [](Psi &self, const py::list &input, py::list &output, int mode) -> int {
        DG_TeeMode teeMode = IntToTeeMode(mode);
        io::Input request(new DG_TeeInput());
        FeedKcalInput(input, &request);
        io::Output response;
        int status = self.Run(request, response, teeMode);
        FeedPsiOutput(response, output, teeMode);
        return status;
    });

    py::class_<Pir, std::shared_ptr<Pir>> pir(m, "Pir");
    pir.def(py::init<std::shared_ptr<Context>>());
    // In-memory database: parallel lists of str keys and str values.
    pir.def("ServerPreProcess", [](Pir &self, const py::list &key, py::list &value) -> int {
        std::unique_ptr<io::KcalPairList> pairs(io::KcalPairList::Create());
        FeedKcalPairList(key, value, pairs.get());
        return self.ServerPreProcess(pairs->Get());
    });
    pir.def("ClientQuery", [](Pir &self, const py::list &input, py::list &output, int mode) -> int {
        DG_DummyMode dummyMode = IntToDummyMode(mode);
        io::Input query(new DG_TeeInput());
        FeedKcalInput(input, &query);
        io::Output answer;
        int status = self.ClientQuery(query, answer, dummyMode);
        FeedKcalOutput(answer, output);
        return status;
    });
    // Same as ClientQuery, except that string sizes exclude the trailing NUL.
    pir.def("ClientQueryByFile", [](Pir &self, const py::list &input, py::list &output, int mode) -> int {
        DG_DummyMode dummyMode = IntToDummyMode(mode);
        io::Input query(new DG_TeeInput());
        FeedKcalInputFile(input, &query);
        io::Output answer;
        int status = self.ClientQuery(query, answer, dummyMode);
        FeedKcalOutput(answer, output);
        return status;
    });
    pir.def("ServerAnswer", [](Pir &self) -> int { return self.ServerAnswer(); });
    // File-based overloads.
    pir.def(
        "ServerPreProcess",
        [](Pir &self, const std::string &input, const std::string &outputPath) -> int {
            return self.ServerPreProcess(input, outputPath);
        },
        py::arg("input"), py::arg("outputPath"));
    pir.def(
        "ServerAnswer",
        [](Pir &self, const std::string &dataPath, int isDeleteCache) -> int {
            return self.ServerAnswer(dataPath, isDeleteCache);
        },
        py::arg("dataPath"), py::arg("isDeleteCache"));
}
// Registers the generic MPC operator base class and the MakeShare /
// RevealShare operators. Each class exposes an in-memory `run` overload and a
// file-based `run` overload, which returns (status, count). The in-memory
// overload is registered first, and that order is kept.
void BindMpcOperators(py::module_ &m)
{
    py::class_<MpcOperatorBase, std::shared_ptr<MpcOperatorBase>> opBase(m, "OperatorBase");
    opBase.def("GetType", &MpcOperatorBase::GetType);
    opBase.def("run",
        [](MpcOperatorBase &self, const std::vector<io::MpcShare *> &shares, io::MpcShare *outShare) -> int {
            // NOTE(review): who frees the set returned by MpcShareSet::Create is not visible here.
            // The Python-facing MpcShareSet.Create binding uses take_ownership, which suggests an
            // owning raw pointer. If so, this leaks one share set per call. TODO confirm.
            auto shareSet = io::MpcShareSet::Create(shares);
            return self.Run(shareSet, outShare);
        });
    opBase.def(
        "run",
        [](MpcOperatorBase &self, const std::vector<std::string> &inputFilePaths,
           std::string &outputFilePath) -> py::tuple {
            int produced = 0;
            int status = self.Run(inputFilePaths, outputFilePath, produced);
            return py::make_tuple(status, produced);
        },
        py::arg("inputFilePaths"), py::arg("outputFilePath"));

    py::class_<MakeShare, std::shared_ptr<MakeShare>> makeShare(m, "MakeShare");
    makeShare.def(py::init<std::shared_ptr<Context>>());
    makeShare.def("run", [](MakeShare &self, const py::list &input, int isRecvShare, io::MpcShare &share) {
        io::Input plain(new DG_TeeInput());
        FeedKcalInput(input, &plain);
        return self.Run(plain, isRecvShare, &share);
    });
    makeShare.def(
        "run",
        [](MakeShare &self, const std::string &inputFilePath, int isRecvShare,
           const std::string &shareFilePath) -> py::tuple {
            int produced = 0;
            int status = self.Run(inputFilePath, isRecvShare, shareFilePath, produced);
            return py::make_tuple(status, produced);
        },
        py::arg("inputFilePath"), py::arg("isRecvShare"), py::arg("shareFilePath"));

    py::class_<RevealShare, std::shared_ptr<RevealShare>> revealShare(m, "RevealShare");
    revealShare.def(py::init<std::shared_ptr<Context>>());
    revealShare.def("run", [](RevealShare &self, const io::MpcShare &share, py::list &output) {
        io::Output revealed;
        int status = self.Run(&share, revealed);
        FeedKcalOutput(revealed, output);
        return status;
    });
    revealShare.def(
        "run",
        [](RevealShare &self, const std::string &shareFilePath, const std::string &outputFilePath) -> py::tuple {
            int produced = 0;
            int status = self.Run(shareFilePath, outputFilePath, produced);
            return py::make_tuple(status, produced);
        },
        py::arg("shareFilePath"), py::arg("outputFilePath"));
}
// Registers ContextExt as the Python class `Context`, with three factories:
//   create(config, send_cb, recv_cb)             - transport via Python callables
//   create_with_yacl(config, desc, rank)         - yacl link, no mesh connection here
//   create_with_link_config(config, desc, rank,
//                           log_details=False)  - yacl link, raises brpc limits, connects mesh
void BindMpcContext(py::module_ &m)
{
    py::class_<ContextExt, std::shared_ptr<ContextExt>> ctxClass(m, "Context");
    ctxClass.def(py::init<>());

    ctxClass.def_static("create", [](Config config, py::function sendCb, py::function recvCb) {
        // The py::function handles are copied into the bridges so they outlive this call.
        auto sendBridge = [sendCb](const TeeNodeInfo &nodeInfo, const uint8_t *data, size_t dataLen) {
            return PyCallbackAdapter::PySendCallback(nodeInfo, data, dataLen, sendCb);
        };
        auto recvBridge = [recvCb](const TeeNodeInfo &nodeInfo, uint8_t *buffer, size_t maxLen) {
            return PyCallbackAdapter::PyRecvCallback(nodeInfo, buffer, maxLen, recvCb);
        };
        return ContextExt::Create(config, sendBridge, recvBridge);
    });

    ctxClass.def_static("create_with_yacl", [](Config config, const yacl::link::ContextDesc &desc, size_t rank) {
        auto linker = YaclLinker::Create(config, desc, rank);
        if (!linker) {
            throw std::runtime_error("Failed to create YaclLinker");
        }
        return ContextExt::CreateFromYaclLinker(config, linker);
    });

    ctxClass.def_static(
        "create_with_link_config",
        [](Config config, const yacl::link::ContextDesc &desc, size_t rank, bool log_details = false) {
            auto linker = YaclLinker::Create(config, desc, rank);
            if (!linker) {
                throw std::runtime_error("Failed to create YaclLinker");
            }
            auto yaclCtx = linker->GetYaclContext();
            if (!yaclCtx) {
                throw std::runtime_error("Failed to get yacl context");
            }
            {
                // The mesh connection does not touch Python objects, so other
                // Python threads may run while it is established.
                py::gil_scoped_release release;
                std::cout << "[create_with_link_config] Calling ConnectToMesh..." << std::endl;
                // Process-wide brpc gflags: lift the body-size cap entirely and the
                // per-socket unwritten-bytes cap to INT64_MAX / MAX_UNWRITTEN_DIV.
                brpc::FLAGS_max_body_size = std::numeric_limits<uint64_t>::max();
                brpc::FLAGS_socket_max_unwritten_bytes = std::numeric_limits<int64_t>::max() / MAX_UNWRITTEN_DIV;
                yaclCtx->ConnectToMesh(log_details ? spdlog::level::info : spdlog::level::debug);
                std::cout << "[create_with_link_config] ConnectToMesh done" << std::endl;
            }
            return ContextExt::CreateFromYaclLinker(config, linker);
        },
        py::arg("config"), py::arg("desc"), py::arg("rank"), py::arg("log_details") = false);

    ctxClass.def(
        "get_yacl_context", [](const std::shared_ptr<ContextExt> &self) { return self->GetYaclContext(); },
        py::return_value_policy::reference);
}
// Registers the operator factory functions (create_psi, create_pir,
// create_make_share, create_reveal_share, create_psi_ub, create_mpc) and the
// PsiUb class. Each factory builds its operator from the context's KCAL context.
//
// pybind11 renders a function's Python signature when the function is bound.
// A return type that is not registered yet at that point shows up as a raw C++
// type name. Psi and Pir (registered by BindOtherOperators) and PsiUb are
// therefore registered before the factories that return them.
void BindMpcCreateInstance(py::module_ &m)
{
    BindOtherOperators(m);
    m.def("create_psi", [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<Psi> {
        return OperatorFactory::CreatePsi(context->GetKcalContext());
    });
    m.def("create_pir", [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<Pir> {
        return OperatorFactory::CreatePir(context->GetKcalContext());
    });
    m.def("create_make_share", [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<MakeShare> {
        return OperatorFactory::CreateMakeShare(context->GetKcalContext());
    });
    m.def("create_reveal_share", [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<RevealShare> {
        return OperatorFactory::CreateRevealShare(context->GetKcalContext());
    });
    // File-based PSI variant: run(inputFilePath, outputFilePath) -> (status, outputCount).
    py::class_<PsiUb, std::shared_ptr<PsiUb>>(m, "PsiUb")
        .def(py::init<std::shared_ptr<Context>>())
        .def("run", [](PsiUb &self, const std::string &inputFilePath, const std::string &outputFilePath) -> py::tuple {
            int outputCount = 0;
            int ret = self.Run(inputFilePath, outputFilePath, outputCount);
            return py::make_tuple(ret, outputCount);
        }, py::arg("inputFilePath"), py::arg("outputFilePath"));
    m.def("create_psi_ub", [](const std::shared_ptr<ContextExt> &context) -> std::shared_ptr<PsiUb> {
        return OperatorFactory::CreatePsiUb(context->GetKcalContext());
    });
    // `type` is cast unchecked to KCAL_AlgorithmsType.
    m.def("create_mpc", [](const std::shared_ptr<ContextExt> &context, int type) -> std::shared_ptr<MpcOperatorBase> {
        KCAL_AlgorithmsType algType = IntToAlgorithmsType(type);
        return OperatorFactory::CreateMpc(context->GetKcalContext(), algType);
    });
}
// Entry point of the `libkcal` Python extension module.
// Value types (Config, Party, LinkDesc) are registered first, because the
// Context factories bound by BindMpcContext take them as arguments.
PYBIND11_MODULE(libkcal, m)
{
    m.doc() = "KCAL Python bindings.";

    py::class_<Config> config(m, "Config");
    config.def(py::init<>());
    config.def_readwrite("nodeId", &Config::nodeId);
    config.def_readwrite("fixBits", &Config::fixBits);
    config.def_readwrite("threadCount", &Config::threadCount);
    config.def_readwrite("worldSize", &Config::worldSize);
    config.def_readwrite("useSMAlg", &Config::useSMAlg);
    config.def_readwrite("chunkSize", &Config::chunkSize);
    config.def_readwrite("bucketCount", &Config::bucketCount);
    config.def_readwrite("tmpPath", &Config::tmpPath);

    using LinkDesc = yacl::link::ContextDesc;

    py::class_<LinkDesc::Party> party(m, "Party");
    party.def(py::init<>());
    party.def(py::init<const std::string &, const std::string &>());
    party.def_readonly("id", &LinkDesc::Party::id);
    party.def_readonly("host", &LinkDesc::Party::host);

    py::class_<LinkDesc> linkDesc(m, "LinkDesc");
    linkDesc.def(py::init<>());
    linkDesc.def_readwrite("id", &LinkDesc::id);
    // Read-only from Python; parties are appended through add_party.
    linkDesc.def_readonly("parties", &LinkDesc::parties);
    linkDesc.def_readwrite("connect_retry_times", &LinkDesc::connect_retry_times);
    linkDesc.def_readwrite("connect_retry_interval_ms", &LinkDesc::connect_retry_interval_ms);
    linkDesc.def_readwrite("recv_timeout_ms", &LinkDesc::recv_timeout_ms);
    linkDesc.def_readwrite("http_max_payload_size", &LinkDesc::http_max_payload_size);
    linkDesc.def_readwrite("http_timeout_ms", &LinkDesc::http_timeout_ms);
    linkDesc.def_readwrite("throttle_window_size", &LinkDesc::throttle_window_size);
    linkDesc.def_readwrite("link_type", &LinkDesc::link_type);
    linkDesc.def(
        "add_party",
        [](LinkDesc &desc, std::string id, std::string host) {
            desc.parties.push_back({std::move(id), std::move(host)});
        },
        "add a party to the link");

    BindMpcContext(m);
    BindIoClasses(m);
    BindMpcOperators(m);
    BindMpcCreateInstance(m);
}
}