#include <pybind11/pybind11.h>
#include <structmember.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/THP.h>
#include "third_party/acl/inc/acl/acl.h"
#include "third_party/acl/inc/acl/acl_base.h"
#include "third_party/acl/inc/acl/acl_rt.h"
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h"
#include "torch_npu/csrc/npu/Module.h"
#include "torch_npu/csrc/npu/Stream.h"
// Python type object for the NPU stream base class, held as a PyObject*.
// Starts out null; THNPStream_init points it at THNPStreamType. The
// conversion helpers at the end of this file use it for isinstance checks.
PyObject *THNPStreamClass = nullptr;
// tp_new implementation for torch_npu._C._NPUStreamBase.
//
// Accepts the optional keyword arguments (in positional order) priority,
// is_sync_launch, stream_id, device_index, device_type and stream_ptr.
//
// Stream selection:
//   * if any of stream_id / device_index / device_type is non-zero, the stream
//     is rebuilt from that triple with NPUStream::unpack3;
//   * otherwise, a non-zero is_sync_launch takes a stream from the sync-launch
//     pool;
//   * otherwise a stream is taken from the regular pool using `priority`.
//
// Returns a new reference to the allocated object, or nullptr with a Python
// error set (argument parsing failure, allocation failure, or a C++ exception
// translated by HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS).
static PyObject *THNPStream_pynew(
    PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
    HANDLE_TH_ERRORS

    // The value is not used afterwards; the call is kept for its error check.
    int current_device;
    NPU_CHECK_ERROR(c10_npu::GetDevice(&current_device));

    int is_sync_launch = 0;
    int priority = 0;
    int64_t stream_id = 0;
    int64_t device_index = 0;
    int64_t device_type = 0;
    // NOTE(review): stream_ptr is accepted and parsed but never read below.
    uint64_t stream_ptr = 0;

    constexpr const char *kwlist[] = {
        "priority",
        "is_sync_launch",
        "stream_id",
        "device_index",
        "device_type",
        "stream_ptr",
        nullptr};
    if (!PyArg_ParseTupleAndKeywords(
        args,
        kwargs,
        "|iiLLLK",
        const_cast<char **>(kwlist),
        &priority,
        &is_sync_launch,
        &stream_id,
        &device_index,
        &device_type,
        &stream_ptr)) {
        return nullptr;
    }

    THPObjectPtr ptr(type->tp_alloc(type, 0));
    if (!ptr) {
        return nullptr;
    }

    c10_npu::NPUStream stream =
        (stream_id || device_index || device_type) ?
        c10_npu::NPUStream::unpack3(
            stream_id, device_index, static_cast<c10::DeviceType>(device_type)) :
        (is_sync_launch ? c10_npu::getNPUStreamFromSyncLaunchPool() :
                          c10_npu::getStreamFromPool(priority));

    THNPStream *self = (THNPStream *)ptr.get();
    self->stream_id = static_cast<int64_t>(stream.id());
    self->device_index = static_cast<int64_t>(stream.device_index());
    self->device_type = static_cast<int64_t>(stream.device_type());
    // tp_alloc only provides raw storage, so the C++ member is constructed in
    // place; THNPStream_dealloc runs the matching destructor.
    new (&self->npu_stream) c10_npu::NPUStream(stream);

    return (PyObject *)ptr.release();
    END_HANDLE_TH_ERRORS
}
// tp_dealloc: destroys the embedded NPUStream explicitly, then hands the
// object's memory back through the type's tp_free slot.
static void THNPStream_dealloc(THNPStream *self)
{
    self->npu_stream.~NPUStream();

    auto *as_object = reinterpret_cast<PyObject *>(self);
    Py_TYPE(as_object)->tp_free(as_object);
}
// Getter for the `device` property: wraps the stream's device in a new Python
// device object. The second argument is the unused getter closure.
static PyObject* THNPStream_get_device(THNPStream *self, void *unused)
{
    HANDLE_TH_ERRORS
    const auto stream_device = self->npu_stream.device();
    return THPDevice_New(stream_device);
    END_HANDLE_TH_ERRORS
}
// Getter for the `npu_stream` property: exposes the value returned by
// NPUStream::stream() to Python as an integer built from the pointer.
static PyObject* THNPStream_get_npu_stream(THNPStream *self, void *unused)
{
    HANDLE_TH_ERRORS
    auto raw_handle = self->npu_stream.stream();
    return PyLong_FromVoidPtr(raw_handle);
    END_HANDLE_TH_ERRORS
}
// Getter for the `priority` property. Always fails: the check below is
// unconditional, so reading the property raises a NOT_SUPPORT error.
static PyObject* THNPStream_get_priority(THNPStream *self, void *unused)
{
    HANDLE_TH_ERRORS
    // Message spelling fixed ("dose" -> "does") to match priority_range().
    TORCH_CHECK(false, "NPU does not support Stream.get_priority() currently.", PTA_ERROR(ErrCode::NOT_SUPPORT));
    END_HANDLE_TH_ERRORS
}
// Static method `priority_range`. Always fails with a NOT_SUPPORT error.
//
// It is registered with METH_STATIC | METH_NOARGS, and CPython invokes such
// entries with two arguments (self/NULL and a NULL args pointer). The function
// therefore takes the full PyCFunction parameter list, so the call made
// through the method table matches the real signature. Both parameters are
// intentionally unused.
static PyObject* THNPStream_priority_range(PyObject *unused, PyObject *noargs)
{
    HANDLE_TH_ERRORS
    TORCH_CHECK(false, "NPU does not support Stream.priority_range() currently.", PTA_ERROR(ErrCode::NOT_SUPPORT));
    END_HANDLE_TH_ERRORS
}
// Method `query()`: converts the result of NPUStream::query() into a Python
// bool. `noargs` is the unused METH_NOARGS placeholder.
static PyObject* THNPStream_query(THNPStream *self, PyObject *noargs)
{
    HANDLE_TH_ERRORS
    const auto query_result = self->npu_stream.query();
    return PyBool_FromLong(query_result);
    END_HANDLE_TH_ERRORS
}
// Method `synchronize()`: calls NPUStream::synchronize() with the GIL released
// for the duration of the call, then returns None.
static PyObject* THNPStream_synchronize(THNPStream *self, PyObject *noargs)
{
    HANDLE_TH_ERRORS
    auto &target_stream = self->npu_stream;
    {
        // The GIL is re-acquired when this guard leaves scope.
        pybind11::gil_scoped_release release_gil;
        target_stream.synchronize();
    }
    Py_RETURN_NONE;
    END_HANDLE_TH_ERRORS
}
// Method `set_data_preprocess_stream(flag)`: unpacks `arg` as a bool and
// forwards it to NPUStream::setDataPreprocessStream(), then returns None.
//
// The Python argument is unpacked while the GIL is still held; only the call
// into the stream runs with the GIL released.
static PyObject* THNPStream_set_data_preprocess_stream(THNPStream *self, PyObject *arg)
{
    HANDLE_TH_ERRORS
    const bool is_data_preprocess_stream = THPUtils_unpackBool(arg);
    {
        pybind11::gil_scoped_release no_gil;
        self->npu_stream.setDataPreprocessStream(is_data_preprocess_stream);
    }
    Py_RETURN_NONE;
    END_HANDLE_TH_ERRORS
}
// Method `__eq__(other)` (registered with METH_O): compares the wrapped
// NPUStream values of `self` and `other`.
//
// METH_O passes an arbitrary Python object, so `other` is verified to be an
// instance of the NPU stream class before its C++ member is read. For any
// other object NotImplemented is returned, following the Python comparison
// protocol. If the isinstance check itself fails, nullptr is returned with the
// Python error left set.
static PyObject* THNPStream_eq(THNPStream *self, THNPStream *other)
{
    HANDLE_TH_ERRORS
    auto *other_object = reinterpret_cast<PyObject *>(other);
    const int is_stream = PyObject_IsInstance(other_object, THNPStreamClass);
    if (is_stream < 0) {
        return nullptr;
    }
    if (is_stream == 0) {
        Py_RETURN_NOTIMPLEMENTED;
    }
    return PyBool_FromLong(self->npu_stream == other->npu_stream);
    END_HANDLE_TH_ERRORS
}
static struct PyMemberDef THNPStream_members[] = {
{(char*)"stream_id", T_ULONGLONG, offsetof(THNPStream, stream_id), READONLY, nullptr},
{(char*)"device_type", T_ULONGLONG, offsetof(THNPStream, device_type), READONLY, nullptr},
{(char*)"device_index", T_ULONGLONG, offsetof(THNPStream, device_index), READONLY, nullptr},
{nullptr}
};
static struct PyGetSetDef THNPStream_properties[] = {
{"device", (getter)THNPStream_get_device, nullptr, nullptr, nullptr},
{"npu_stream", (getter)THNPStream_get_npu_stream, nullptr, nullptr, nullptr},
{"priority", (getter)THNPStream_get_priority, nullptr, nullptr, nullptr},
{nullptr}
};
// Method table for the stream type. Each handler is cast to PyCFunction
// because its declared parameter types differ from the generic signature.
// The table ends with a null sentinel entry.
static PyMethodDef THNPStream_methods[] = {
    {"query", reinterpret_cast<PyCFunction>(THNPStream_query), METH_NOARGS, nullptr},
    {"synchronize", reinterpret_cast<PyCFunction>(THNPStream_synchronize), METH_NOARGS, nullptr},
    // Static method; the cast goes through void(*)(void) first.
    {"priority_range",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(THNPStream_priority_range)),
     METH_STATIC | METH_NOARGS,
     nullptr},
    {"__eq__", reinterpret_cast<PyCFunction>(THNPStream_eq), METH_O, nullptr},
    {"set_data_preprocess_stream",
     reinterpret_cast<PyCFunction>(THNPStream_set_data_preprocess_stream),
     METH_O,
     nullptr},
    {nullptr}
};
// Static type object for torch_npu._C._NPUStreamBase.
// The initializer is positional, so every entry must stay in PyTypeObject
// field order; the trailing comments name the slot each value lands in
// (per the CPython PyTypeObject layout). tp_base is left null here and is
// filled in by THNPStream_init before PyType_Ready runs.
PyTypeObject THNPStreamType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch_npu._C._NPUStreamBase", // tp_name
sizeof(THNPStream), // tp_basicsize
0, // tp_itemsize
(destructor)THNPStream_dealloc, // tp_dealloc
0, // tp_vectorcall_offset (tp_print before Python 3.8)
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_as_async (tp_reserved in older headers)
nullptr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
nullptr, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags: BASETYPE allows subclassing
nullptr, // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
THNPStream_methods, // tp_methods
THNPStream_members, // tp_members
THNPStream_properties, // tp_getset
nullptr, // tp_base (assigned in THNPStream_init)
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
nullptr, // tp_alloc
THNPStream_pynew, // tp_new
};
// Finalizes THNPStreamType and registers it on `module` as "_NPUStreamBase".
//
// Sets THPStreamClass as the base type (taking a reference to it), publishes
// the type through THNPStreamClass, and throws python_error if either
// PyType_Ready or PyModule_AddObject reports failure.
void THNPStream_init(PyObject *module)
{
    PyObject *stream_type_object = (PyObject *)&THNPStreamType;

    Py_INCREF(THPStreamClass);
    THNPStreamType.tp_base = THPStreamClass;
    THNPStreamClass = stream_type_object;

    const bool type_ready = PyType_Ready(&THNPStreamType) >= 0;
    if (!type_ready) {
        throw python_error();
    }

    // Reference handed to the module by PyModule_AddObject below.
    Py_INCREF(&THNPStreamType);
    const bool registered = PyModule_AddObject(module, "_NPUStreamBase", stream_type_object) >= 0;
    if (!registered) {
        throw python_error();
    }
}
// Converts a Python sequence into a vector of optional NPU streams.
//
// Each element must be either an instance of the NPU stream class (converted
// with NPUStream::unpack3 from its stored stream_id / device_index /
// device_type) or None (converted to an empty optional).
//
// Throws std::runtime_error if `obj` is not a sequence, if it cannot be
// materialized by PySequence_Fast, or if an element is of any other type.
// Throws python_error if the isinstance check itself raises.
std::vector<c10::optional<c10_npu::NPUStream>> THNPUtils_PySequence_to_NPUStreamList(PyObject* obj)
{
    if (!PySequence_Check(obj)) {
        throw std::runtime_error("Expected a sequence in THNPUtils_PySequence_to_NPUStreamList" + PTA_ERROR(ErrCode::PARAM));
    }
    THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
    if (seq.get() == nullptr) {
        throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj)) + PTA_ERROR(ErrCode::PARAM));
    }

    const Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
    std::vector<c10::optional<c10_npu::NPUStream>> streams;
    streams.reserve(static_cast<size_t>(length));

    for (Py_ssize_t i = 0; i < length; i++) {
        // Borrowed reference; `seq` keeps the item alive for the loop body.
        PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i);

        // PyObject_IsInstance returns -1 on error, which must not be read as
        // "is an instance".
        const int is_stream = PyObject_IsInstance(stream, THNPStreamClass);
        if (is_stream < 0) {
            throw python_error();
        }

        if (is_stream == 1) {
            auto *npu_stream = reinterpret_cast<THNPStream*>(stream);
            streams.emplace_back(c10_npu::NPUStream::unpack3(
                npu_stream->stream_id,
                npu_stream->device_index,
                static_cast<c10::DeviceType>(npu_stream->device_type)));
        } else if (stream == Py_None) {
            streams.emplace_back();
        } else {
            // The exception was previously constructed but never thrown, so
            // invalid elements were silently skipped.
            throw std::runtime_error("Unknown data type found in stream list. Need torch_npu.npu.Stream or None" + PTA_ERROR(ErrCode::TYPE));
        }
    }
    return streams;
}
// Converts a Python NPU stream object into a c10_npu::NPUStream.
//
// Fails the TORCH_CHECK (with a TYPE error code, as the other errors in this
// file carry) unless `stream` is an instance of the NPU stream class. The
// result is compared against 1 because PyObject_IsInstance returns -1 on
// error, which would otherwise pass a truthiness test.
c10_npu::NPUStream THNPUtils_PyObject_to_NPUStream(PyObject* stream)
{
    TORCH_CHECK(PyObject_IsInstance(stream, THNPStreamClass) == 1,
        "Need torch_npu.npu.Stream argument type.", PTA_ERROR(ErrCode::TYPE));

    auto *npu_stream = reinterpret_cast<THNPStream *>(stream);
    return c10_npu::NPUStream::unpack3(
        npu_stream->stream_id,
        npu_stream->device_index,
        static_cast<c10::DeviceType>(npu_stream->device_type));
}