// Source snapshot: commit ec2ff2f0, created 2025-08-29 (banner from the repository's commit-history view).
#include "torch_npu/csrc/afd/Init.h"

#include <torch/csrc/python_headers.h>

#include <torch/csrc/Exceptions.h>

#include <torch/csrc/utils/object_ptr.h>

#include <torch/csrc/utils/pybind.h>

#include "torch_npu/csrc/afd/ScheduleContext.h"



namespace torch_npu {

namespace afd {



/**
 * Python entry point (registered below as `_afd_init`, METH_NOARGS).
 *
 * Imports `torch_npu._C`, creates the `_afd` submodule on it and registers
 * the `ScheduleContextHolder` class binding in that submodule.
 *
 * @param _unused  Not used.
 * @param noargs   Not used.
 * @return A new reference to `Py_True` on success. On failure, `nullptr`
 *         with a Python exception set.
 */
PyObject *afd_init(PyObject * _unused, PyObject * noargs)
{
    // This function is called directly by the CPython interpreter, so no C++
    // exception may escape it. HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS convert
    // anything thrown below (python_error, or errors raised by pybind11 while
    // registering the binding) into a Python exception and return nullptr.
    HANDLE_TH_ERRORS
    auto torch_npu_C_module = THPObjectPtr(PyImport_ImportModule("torch_npu._C"));
    if (!torch_npu_C_module) {
        // PyImport_ImportModule has already set the Python error indicator.
        throw python_error();
    }
    auto torch_npu_C_m = py::handle(torch_npu_C_module).cast<py::module>();

    // def_submodule already yields a py::module, so no further cast is needed.
    auto module = torch_npu_C_m.def_submodule("_afd", "Attention-FFN Disaggregation");

    py::class_<ScheduleContextHolder>(module, "ScheduleContextHolder")
        .def(py::init<int32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint64_t, uint64_t,
            uint64_t, uint64_t>())
        .def("init", &ScheduleContextHolder::Init)
        .def("get_context_tensor", &ScheduleContextHolder::GetContextTensor)
        .def("stop_schedule", &ScheduleContextHolder::StopSchedule)
        .def("get_schedule_context_info", &ScheduleContextHolder::GetScheduleContextInfo);
    Py_RETURN_TRUE;
    END_HANDLE_TH_ERRORS
}



// Python method table for the AFD bindings; handed out by python_functions().
// NOTE(review): the original comment said these are methods on torch._C, while
// afd_init imports "torch_npu._C" — confirm the owning module at the
// registration site.
PyMethodDef methods[] = {
    // name, C function, calling convention, docstring
    {"_afd_init", afd_init, METH_NOARGS, nullptr},
    // All-null sentinel entry terminating the table.
    {nullptr, nullptr, 0, nullptr},
};



/**
 * Accessor for the AFD Python method table.
 *
 * @return Pointer to the first entry of `methods`, the table defined in this
 *         file that ends with an all-null sentinel entry.
 */
PyMethodDef *python_functions()
{
    return &methods[0];
}



} // namespace afd

} // namespace torch_npu