#include <ATen/Tensor.h>
#include <ATen/ops/add.h>
#include <c10/core/DeviceType.h>
#include <torch/extension.h>
#include <torch_npu/csrc/core/npu/NPUFunctions.h>
#include <torch_npu/csrc/core/npu/NPUGuard.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/inductor/aoti_torch/c/shim.h>
#include <torch_npu/csrc/inductor/aoti_torch/utils.h>
#include <mutex>
#include <optional>
#include <unordered_map>
#include <vector>
namespace {
// Returns the single mutex used to serialize access to the stream registry.
// The mutex is a function-local static, so every call yields the same object.
std::mutex& stream_registry_mutex()
{
    static std::mutex registry_lock;
    return registry_lock;
}
std::unordered_map<void*, c10::Stream>& stream_registry()
{
static std::unordered_map<void*, c10::Stream> registry;
return registry;
}
// Records `stream` in the registry under its raw stream pointer, replacing
// any entry already stored for that pointer.
void remember_stream(const c10_npu::NPUStream& stream)
{
    void* const key = reinterpret_cast<void*>(stream.stream(false));

    // Hold the registry mutex only for the map update itself.
    std::lock_guard<std::mutex> registry_guard(stream_registry_mutex());
    auto& registry = stream_registry();
    registry.insert_or_assign(key, static_cast<c10::Stream>(stream));
}
// Raises through TORCH_CHECK unless `err` equals AOTI_TORCH_SUCCESS.
// `call` names the shim call being checked and is placed at the front of the
// error message, followed by the numeric error code.
void check_aoti_error(AOTITorchError err, const char* call)
{
    const bool succeeded = (err == AOTI_TORCH_SUCCESS);
    TORCH_CHECK(succeeded, call, " failed with error code ", err);
}
// Copies the elements viewed by `values` into an owning std::vector so the
// caller can hand out a stable `data()` pointer.
std::vector<int64_t> to_vector(c10::IntArrayRef values)
{
    return values.vec();
}
// Validates the preconditions shared by every helper in this file: `tensor`
// must live on a PrivateUse1 device, be float32, and use the strided layout.
// `op_name` prefixes the error message. The checks run in that order and the
// first one that fails raises through TORCH_CHECK.
void assert_npu_tensor(const at::Tensor& tensor, const char* op_name)
{
    const bool on_npu = tensor.device().type() == c10::DeviceType::PrivateUse1;
    TORCH_CHECK(on_npu, op_name, " expects an NPU tensor");

    const bool is_float32 = tensor.scalar_type() == at::kFloat;
    TORCH_CHECK(is_float32, op_name, " currently tests float32 tensors only");

    const bool is_strided = tensor.layout() == at::kStrided;
    TORCH_CHECK(is_strided, op_name, " expects a strided tensor");
}
// Returns the current stream of the tensor's device, as reported by the shim,
// reinterpreted as an integer. Before returning, the shim's answer is
// cross-checked against c10_npu::getCurrentNPUStream; a mismatch raises.
int64_t get_npu_raw_stream_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "get_npu_raw_stream");
    const auto device_index = tensor.device().index();

    void* from_shim = nullptr;
    const AOTITorchError status =
        aoti_torch_get_current_npu_stream(device_index, &from_shim);
    check_aoti_error(status, "aoti_torch_get_current_npu_stream");

    auto runtime_stream = c10_npu::getCurrentNPUStream(device_index);
    void* from_runtime = reinterpret_cast<void*>(runtime_stream.stream(false));
    TORCH_CHECK(
        from_shim == from_runtime,
        "aoti_torch_get_current_npu_stream returned ",
        from_shim,
        ", but c10_npu returned ",
        from_runtime);

    return reinterpret_cast<int64_t>(from_shim);
}
// Builds a zero-element float32 tensor on the same NPU device as `tensor` by
// passing a null data pointer to aoti_torch_create_tensor_from_blob_npu, then
// verifies its device, dtype, layout, sizes, strides and numel.
//
// Returns the created tensor. Raises through TORCH_CHECK if the shim call
// fails or any property does not match.
at::Tensor make_zero_size_blob_tensor_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "make_zero_size_blob_tensor");
    const auto device_index = tensor.device().index();
    std::vector<int64_t> sizes = {0};
    std::vector<int64_t> strides = {1};

    AtenTensorHandle handle = nullptr;
    check_aoti_error(
        aoti_torch_create_tensor_from_blob_npu(
            nullptr,
            static_cast<int64_t>(sizes.size()),
            sizes.data(),
            strides.data(),
            0,
            aoti_torch_dtype_float32(),
            aoti_torch_device_type_npu(),
            device_index,
            &handle),
        "aoti_torch_create_tensor_from_blob_npu(nullptr, npu)");
    // Guard the dereference below against a "success" that left no handle.
    TORCH_CHECK(
        handle != nullptr,
        "aoti_torch_create_tensor_from_blob_npu reported success but returned a null handle");

    // Take our own reference to the tensor, then release the shim handle
    // immediately so that a failing property check below cannot leak it.
    at::Tensor result =
        *torch::aot_inductor::tensor_handle_to_tensor_pointer(handle);
    check_aoti_error(
        aoti_torch_delete_tensor_object(handle),
        "aoti_torch_delete_tensor_object(zero_size_npu)");

    TORCH_CHECK(result.device() == tensor.device(), "Zero-size NPU tensor device mismatch");
    TORCH_CHECK(result.scalar_type() == at::kFloat, "Zero-size NPU tensor dtype mismatch");
    TORCH_CHECK(result.layout() == at::kStrided, "Zero-size NPU tensor layout mismatch");
    TORCH_CHECK(result.sizes().vec() == sizes, "Zero-size NPU tensor sizes mismatch");
    TORCH_CHECK(result.strides().vec() == strides, "Zero-size NPU tensor strides mismatch");
    TORCH_CHECK(result.numel() == 0, "Zero-size NPU tensor should be empty");
    return result;
}
// Builds a zero-element float32 tensor by calling
// aoti_torch_create_tensor_from_blob_npu with the CPU device type and a null
// data pointer, then verifies its device type, dtype, layout, sizes, strides
// and numel. `tensor` is only used for the shared NPU precondition check.
//
// Returns the created tensor. Raises through TORCH_CHECK if the shim call
// fails or any property does not match.
at::Tensor make_zero_size_cpu_blob_tensor_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "make_zero_size_cpu_blob_tensor");
    std::vector<int64_t> sizes = {0};
    std::vector<int64_t> strides = {1};

    AtenTensorHandle handle = nullptr;
    check_aoti_error(
        aoti_torch_create_tensor_from_blob_npu(
            nullptr,
            static_cast<int64_t>(sizes.size()),
            sizes.data(),
            strides.data(),
            0,
            aoti_torch_dtype_float32(),
            aoti_torch_device_type_cpu(),
            0,
            &handle),
        "aoti_torch_create_tensor_from_blob_npu(nullptr, cpu)");
    // Guard the dereference below against a "success" that left no handle.
    TORCH_CHECK(
        handle != nullptr,
        "aoti_torch_create_tensor_from_blob_npu reported success but returned a null handle");

    // Take our own reference to the tensor, then release the shim handle
    // immediately so that a failing property check below cannot leak it.
    at::Tensor result =
        *torch::aot_inductor::tensor_handle_to_tensor_pointer(handle);
    check_aoti_error(
        aoti_torch_delete_tensor_object(handle),
        "aoti_torch_delete_tensor_object(zero_size_cpu)");

    TORCH_CHECK(result.device().type() == c10::DeviceType::CPU, "Zero-size CPU tensor device mismatch");
    TORCH_CHECK(result.scalar_type() == at::kFloat, "Zero-size CPU tensor dtype mismatch");
    TORCH_CHECK(result.layout() == at::kStrided, "Zero-size CPU tensor layout mismatch");
    TORCH_CHECK(result.sizes().vec() == sizes, "Zero-size CPU tensor sizes mismatch");
    TORCH_CHECK(result.strides().vec() == strides, "Zero-size CPU tensor strides mismatch");
    TORCH_CHECK(result.numel() == 0, "Zero-size CPU tensor should be empty");
    return result;
}
// Asks aoti_torch_create_tensor_from_blob_npu_v2 to wrap the tensor's data
// with the mkldnn layout and expects the call to return AOTI_TORCH_FAILURE
// while leaving the output handle null. Returns 1 when both expectations
// hold; otherwise raises through TORCH_CHECK.
int64_t check_mkldnn_blob_tensor_v2_rejected_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_mkldnn_blob_tensor_v2_rejected");

    auto dense = tensor.contiguous();
    auto shape = to_vector(dense.sizes());
    auto steps = to_vector(dense.strides());
    const auto mkldnn_layout = static_cast<int32_t>(at::kMkldnn);

    AtenTensorHandle created = nullptr;
    const auto status = aoti_torch_create_tensor_from_blob_npu_v2(
        dense.data_ptr(),
        dense.dim(),
        shape.data(),
        steps.data(),
        dense.storage_offset(),
        aoti_torch_dtype_float32(),
        aoti_torch_device_type_npu(),
        tensor.device().index(),
        &created,
        mkldnn_layout,
        nullptr,
        0);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "mkldnn layout should make aoti_torch_create_tensor_from_blob_npu_v2 fail");
    TORCH_CHECK(created == nullptr, "Handle should remain nullptr when mkldnn layout is rejected");
    return 1;
}
// Calls aoti_torch_create_tensor_from_blob_npu_v2 with a device index equal
// to c10_npu::device_count() (one past the last valid index) and expects
// AOTI_TORCH_FAILURE with the output handle left null. Returns 1 when both
// expectations hold; otherwise raises through TORCH_CHECK.
int64_t check_blob_tensor_v2_propagates_invalid_device_failure_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_blob_tensor_v2_propagates_invalid_device_failure");

    auto dense = tensor.contiguous();
    auto shape = to_vector(dense.sizes());
    auto steps = to_vector(dense.strides());
    const auto out_of_range_device = static_cast<int32_t>(c10_npu::device_count());

    AtenTensorHandle created = nullptr;
    const auto status = aoti_torch_create_tensor_from_blob_npu_v2(
        dense.data_ptr(),
        dense.dim(),
        shape.data(),
        steps.data(),
        dense.storage_offset(),
        aoti_torch_dtype_float32(),
        aoti_torch_device_type_npu(),
        out_of_range_device,
        &created,
        aoti_torch_layout_strided(),
        nullptr,
        0);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Invalid device should make aoti_torch_create_tensor_from_blob_npu_v2 fail");
    TORCH_CHECK(
        created == nullptr,
        "Handle should remain nullptr when the inner blob tensor path fails");
    return 1;
}
// Passes nullptr to each of the three shim delete entry points (device guard,
// stream guard, caching-allocator raw delete) and expects every one of them
// to report success. Returns 1 on success; otherwise raises.
int64_t check_null_delete_paths_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_delete_paths");

    const AOTITorchError guard_status = aoti_torch_delete_npu_guard(nullptr);
    check_aoti_error(guard_status, "aoti_torch_delete_npu_guard(nullptr)");

    const AOTITorchError stream_guard_status = aoti_torch_delete_npu_stream_guard(nullptr);
    check_aoti_error(stream_guard_status, "aoti_torch_delete_npu_stream_guard(nullptr)");

    const AOTITorchError raw_delete_status = aoti_torch_npu_caching_allocator_raw_delete(nullptr);
    check_aoti_error(raw_delete_status, "aoti_torch_npu_caching_allocator_raw_delete(nullptr)");

    return 1;
}
// Hands aoti_torch_create_npu_stream_guard a made-up stream pointer (0x1) and
// expects AOTI_TORCH_FAILURE with the output guard left null. Returns 1 when
// both expectations hold; otherwise raises through TORCH_CHECK.
int64_t check_invalid_stream_guard_path_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_invalid_stream_guard_path");

    void* const bogus_stream = reinterpret_cast<void*>(0x1);
    NPUStreamGuardHandle created = nullptr;
    const auto status = aoti_torch_create_npu_stream_guard(
        bogus_stream,
        tensor.device().index(),
        &created);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Invalid stream pointer should make aoti_torch_create_npu_stream_guard fail");
    TORCH_CHECK(created == nullptr, "Guard should remain nullptr when stream guard creation fails");
    return 1;
}
// Hands aoti_torch_create_npu_stream_guard a null stream pointer and expects
// AOTI_TORCH_FAILURE with the output guard left null. Returns 1 when both
// expectations hold; otherwise raises through TORCH_CHECK.
int64_t check_null_stream_guard_path_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_stream_guard_path");

    const auto device_index = tensor.device().index();
    NPUStreamGuardHandle created = nullptr;
    const auto status = aoti_torch_create_npu_stream_guard(nullptr, device_index, &created);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Null stream pointer should make aoti_torch_create_npu_stream_guard fail");
    TORCH_CHECK(created == nullptr, "Guard should remain nullptr when null stream guard creation fails");
    return 1;
}
// Requests a device guard for index c10_npu::device_count() (one past the
// last valid index) and expects AOTI_TORCH_FAILURE with the output guard left
// null. Returns 1 when both expectations hold; otherwise raises.
int64_t check_invalid_device_guard_creation_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_invalid_device_guard_creation");

    const auto out_of_range_device = static_cast<int32_t>(c10_npu::device_count());
    NPUGuardHandle created = nullptr;
    const auto status = aoti_torch_create_npu_guard(out_of_range_device, &created);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Invalid device index should make aoti_torch_create_npu_guard fail");
    TORCH_CHECK(created == nullptr, "Guard should remain nullptr when invalid guard creation fails");
    return 1;
}
// Creates a shim device guard on the tensor's device, then asks it to switch
// to index c10_npu::device_count() (one past the last valid index). Expects:
//   * guard creation selects the tensor's device,
//   * the invalid set_index returns AOTI_TORCH_FAILURE and leaves the device
//     selection unchanged,
//   * deleting the guard restores the device that was current on entry.
// Returns 1 when all expectations hold; otherwise raises through TORCH_CHECK.
int64_t check_invalid_device_guard_set_index_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_invalid_device_guard_set_index");
    const auto original_device = c10_npu::current_device();
    const auto valid_device_index = tensor.device().index();
    const auto invalid_device_index = static_cast<int32_t>(c10_npu::device_count());

    // Deletes the guard if a check below throws, so a failing expectation does
    // not leave the process on the switched device or leak the guard object.
    struct GuardCleanup {
        NPUGuardHandle handle = nullptr;
        ~GuardCleanup()
        {
            if (handle != nullptr) {
                // Best effort only: we are already unwinding from a failure.
                static_cast<void>(aoti_torch_delete_npu_guard(handle));
            }
        }
    } cleanup;

    check_aoti_error(
        aoti_torch_create_npu_guard(valid_device_index, &cleanup.handle),
        "aoti_torch_create_npu_guard(valid for set_index)");
    TORCH_CHECK(
        c10_npu::current_device() == valid_device_index,
        "Valid guard creation should switch to the requested device");

    auto err = aoti_torch_npu_guard_set_index(cleanup.handle, invalid_device_index);
    TORCH_CHECK(
        err == AOTI_TORCH_FAILURE,
        "Invalid device index should make aoti_torch_npu_guard_set_index fail");
    TORCH_CHECK(
        c10_npu::current_device() == valid_device_index,
        "Failed guard set_index should keep the previously selected device");

    // Normal path: disarm the cleanup first so the guard is deleted exactly
    // once, then delete explicitly so the returned status is still verified.
    NPUGuardHandle guard = cleanup.handle;
    cleanup.handle = nullptr;
    check_aoti_error(
        aoti_torch_delete_npu_guard(guard),
        "aoti_torch_delete_npu_guard(valid after failed set_index)");
    TORCH_CHECK(
        c10_npu::current_device() == original_device,
        "Deleting the guard should restore the original device after failed set_index");
    return 1;
}
// Queries the current stream for index c10_npu::device_count() (one past the
// last valid index) and expects AOTI_TORCH_FAILURE with the output pointer
// left null. Returns 1 when both expectations hold; otherwise raises.
int64_t check_invalid_device_current_stream_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_invalid_device_current_stream");

    const auto out_of_range_device = static_cast<int32_t>(c10_npu::device_count());
    void* looked_up = nullptr;
    const auto status = aoti_torch_get_current_npu_stream(out_of_range_device, &looked_up);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Invalid device index should make aoti_torch_get_current_npu_stream fail");
    TORCH_CHECK(looked_up == nullptr, "Invalid device stream lookup should leave stream as nullptr");
    return 1;
}
// Calls aoti_torch_create_npu_stream_guard with a pooled stream (registered
// with the stream registry first) but a null output pointer, and expects
// AOTI_TORCH_FAILURE. Returns 1 on success; otherwise raises.
int64_t check_null_stream_guard_output_handle_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_stream_guard_output_handle");

    const auto device_index = tensor.device().index();
    auto side_stream = c10_npu::getNPUStreamFromPool(device_index);
    remember_stream(side_stream);

    void* const raw_side_stream = reinterpret_cast<void*>(side_stream.stream(false));
    const auto status =
        aoti_torch_create_npu_stream_guard(raw_side_stream, device_index, nullptr);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Null ret_guard should make aoti_torch_create_npu_stream_guard fail");
    return 1;
}
// Calls aoti_torch_get_current_npu_stream with a null output pointer and
// expects AOTI_TORCH_FAILURE. Returns 1 on success; otherwise raises.
int64_t check_null_current_stream_output_handle_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_current_stream_output_handle");

    const auto device_index = tensor.device().index();
    const auto status = aoti_torch_get_current_npu_stream(device_index, nullptr);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Null ret_stream should make aoti_torch_get_current_npu_stream fail");
    return 1;
}
// Calls aoti_torch_create_npu_guard with a null output pointer and expects
// AOTI_TORCH_FAILURE. Returns 1 on success; otherwise raises.
int64_t check_null_guard_output_handle_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_guard_output_handle");

    const auto device_index = tensor.device().index();
    const auto status = aoti_torch_create_npu_guard(device_index, nullptr);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Null ret_guard should make aoti_torch_create_npu_guard fail");
    return 1;
}
// Requests a 64-byte allocation from aoti_torch_npu_caching_allocator_raw_alloc
// with a null output pointer and expects AOTI_TORCH_FAILURE. Returns 1 on
// success; otherwise raises.
int64_t check_null_allocator_output_handle_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_allocator_output_handle");

    const auto status = aoti_torch_npu_caching_allocator_raw_alloc(64, nullptr);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Null ret_ptr should make aoti_torch_npu_caching_allocator_raw_alloc fail");
    return 1;
}
// Calls aoti_torch_create_tensor_from_blob_npu with valid tensor metadata but
// a null output pointer and expects AOTI_TORCH_FAILURE. Returns 1 on success;
// otherwise raises.
int64_t check_null_blob_tensor_output_handle_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_blob_tensor_output_handle");

    auto dense = tensor.contiguous();
    auto shape = to_vector(dense.sizes());
    auto steps = to_vector(dense.strides());

    const auto status = aoti_torch_create_tensor_from_blob_npu(
        dense.data_ptr(),
        dense.dim(),
        shape.data(),
        steps.data(),
        dense.storage_offset(),
        aoti_torch_dtype_float32(),
        aoti_torch_device_type_npu(),
        tensor.device().index(),
        nullptr);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Null ret_new_tensor should make aoti_torch_create_tensor_from_blob_npu fail");
    return 1;
}
// Calls aoti_torch_create_tensor_from_blob_npu_v2 with valid tensor metadata
// and the strided layout but a null output pointer, and expects
// AOTI_TORCH_FAILURE. Returns 1 on success; otherwise raises.
int64_t check_null_blob_tensor_v2_output_handle_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_null_blob_tensor_v2_output_handle");

    auto dense = tensor.contiguous();
    auto shape = to_vector(dense.sizes());
    auto steps = to_vector(dense.strides());

    const auto status = aoti_torch_create_tensor_from_blob_npu_v2(
        dense.data_ptr(),
        dense.dim(),
        shape.data(),
        steps.data(),
        dense.storage_offset(),
        aoti_torch_dtype_float32(),
        aoti_torch_device_type_npu(),
        tensor.device().index(),
        nullptr,
        aoti_torch_layout_strided(),
        nullptr,
        0);

    const bool rejected = (status == AOTI_TORCH_FAILURE);
    TORCH_CHECK(
        rejected,
        "Null ret_new_tensor should make aoti_torch_create_tensor_from_blob_npu_v2 fail");
    return 1;
}
// Selects the tensor's device with a c10_npu::NPUGuard, then asks the shim
// for the current stream using device index -1 and expects the same raw
// stream that c10_npu::getCurrentNPUStream reports for the tensor's device.
// Also verifies that the lookup does not change the selected device and that
// pointing the guard back at the entry device restores it.
// Returns 1 when all expectations hold; otherwise raises through TORCH_CHECK.
int64_t check_current_device_stream_lookup_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_current_device_stream_lookup");
    const auto original_device = c10_npu::current_device();
    const auto device_index = tensor.device().index();

    c10_npu::NPUGuard device_guard(device_index);
    // Verify the guard took effect *before* calling the shim, so a failure
    // here is attributed to the guard and not to the lookup that follows.
    TORCH_CHECK(
        c10_npu::current_device() == device_index,
        "Device guard should keep the tensor device selected inside the scope");

    void* shim_stream = nullptr;
    check_aoti_error(
        aoti_torch_get_current_npu_stream(-1, &shim_stream),
        "aoti_torch_get_current_npu_stream(current device)");
    auto current_stream =
        reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));
    TORCH_CHECK(
        shim_stream == current_stream,
        "Current-device stream fallback mismatch");
    // The lookup itself must not have switched devices.
    TORCH_CHECK(
        c10_npu::current_device() == device_index,
        "Current-device lookup should keep the selected device");

    device_guard.set_index(original_device);
    TORCH_CHECK(
        c10_npu::current_device() == original_device,
        "Current-device lookup helper should restore the original device");
    return 1;
}
// Makes a pooled stream current via c10_npu::NPUStreamGuard, then creates a
// shim stream guard for the device's default stream. Expects the default
// stream to become current while the shim guard exists and the pooled stream
// to be current again once the shim guard is deleted.
// Returns 1 when all expectations hold; otherwise raises through TORCH_CHECK.
int64_t check_default_stream_guard_roundtrip_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "check_default_stream_guard_roundtrip");
    const auto device_index = tensor.device().index();

    auto pooled_stream = c10_npu::getNPUStreamFromPool(device_index);
    remember_stream(pooled_stream);
    c10_npu::NPUStreamGuard outer_guard(static_cast<c10::Stream>(pooled_stream));
    auto original_stream =
        reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));

    auto default_stream = c10_npu::getDefaultNPUStream(device_index);
    remember_stream(default_stream);

    // Deletes the shim guard if a check below throws. Declared after
    // outer_guard so it is destroyed first and the guards unwind innermost
    // first.
    struct StreamGuardCleanup {
        NPUStreamGuardHandle handle = nullptr;
        ~StreamGuardCleanup()
        {
            if (handle != nullptr) {
                // Best effort only: we are already unwinding from a failure.
                static_cast<void>(aoti_torch_delete_npu_stream_guard(handle));
            }
        }
    } cleanup;

    check_aoti_error(
        aoti_torch_create_npu_stream_guard(
            reinterpret_cast<void*>(default_stream.stream(false)),
            device_index,
            &cleanup.handle),
        "aoti_torch_create_npu_stream_guard(default stream)");
    auto guarded_stream =
        reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));
    TORCH_CHECK(
        guarded_stream == reinterpret_cast<void*>(default_stream.stream(false)),
        "Default stream guard did not switch to the default stream");

    // Normal path: disarm the cleanup first so the guard is deleted exactly
    // once, then delete explicitly so the returned status is still verified.
    NPUStreamGuardHandle stream_guard = cleanup.handle;
    cleanup.handle = nullptr;
    check_aoti_error(
        aoti_torch_delete_npu_stream_guard(stream_guard),
        "aoti_torch_delete_npu_stream_guard(default stream)");

    auto restored_stream =
        reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));
    TORCH_CHECK(
        restored_stream == original_stream,
        "Default stream guard did not restore the original stream");
    return 1;
}
// End-to-end smoke test of the NPU shim. In order, it checks:
//   1. aoti_torch_device_type_npu() equals DeviceType::PrivateUse1;
//   2. the shim's current-stream lookup matches c10_npu::getCurrentNPUStream;
//   3. a shim device guard selects the tensor's device, accepts set_index to
//      the same device, and restores the entry device when deleted;
//   4. a shim stream guard makes a pooled stream current and restores the
//      previous stream when deleted;
//   5. the raw allocator returns nullptr for 0 bytes and a non-null pointer
//      for 64 bytes, and the raw delete succeeds;
//   6. a blob tensor created over the input's data matches it in device,
//      dtype, data pointer, sizes, strides, storage offset and values.
// Returns `alias + 1.0`, where `alias` is the blob tensor from step 6.
// Raises through TORCH_CHECK on the first expectation that fails.
at::Tensor run_npu_shim_checks_impl(const at::Tensor& tensor)
{
    assert_npu_tensor(tensor, "run_npu_shim_checks");
    const auto device_index = tensor.device().index();

    // 1. Device type constant.
    TORCH_CHECK(
        aoti_torch_device_type_npu() ==
            static_cast<int32_t>(c10::DeviceType::PrivateUse1),
        "aoti_torch_device_type_npu should return PrivateUse1");

    // 2. Current stream lookup.
    void* shim_stream = nullptr;
    check_aoti_error(
        aoti_torch_get_current_npu_stream(device_index, &shim_stream),
        "aoti_torch_get_current_npu_stream");
    auto current_stream =
        reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));
    TORCH_CHECK(
        shim_stream == current_stream,
        "Current stream mismatch: shim=",
        shim_stream,
        ", direct=",
        current_stream);

    // 3. Device guard round trip.
    const auto original_device = c10_npu::current_device();
    {
        // Deletes the guard if a check in this scope throws, so a failure
        // does not leave the process on the switched device.
        struct DeviceGuardCleanup {
            NPUGuardHandle handle = nullptr;
            ~DeviceGuardCleanup()
            {
                if (handle != nullptr) {
                    // Best effort only: already unwinding from a failure.
                    static_cast<void>(aoti_torch_delete_npu_guard(handle));
                }
            }
        } device_cleanup;

        check_aoti_error(
            aoti_torch_create_npu_guard(device_index, &device_cleanup.handle),
            "aoti_torch_create_npu_guard");
        TORCH_CHECK(
            c10_npu::current_device() == device_index,
            "NPU guard did not switch the current device");
        check_aoti_error(
            aoti_torch_npu_guard_set_index(device_cleanup.handle, device_index),
            "aoti_torch_npu_guard_set_index");
        TORCH_CHECK(
            c10_npu::current_device() == device_index,
            "NPU guard set_index did not keep the requested device");

        // Disarm before the explicit delete so the guard is freed only once
        // while the delete status is still verified.
        NPUGuardHandle guard = device_cleanup.handle;
        device_cleanup.handle = nullptr;
        check_aoti_error(
            aoti_torch_delete_npu_guard(guard),
            "aoti_torch_delete_npu_guard");
    }
    TORCH_CHECK(
        c10_npu::current_device() == original_device,
        "Deleting the NPU guard did not restore the original device");

    // 4. Stream guard round trip.
    auto original_stream =
        reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));
    auto pooled_stream_obj = c10_npu::getNPUStreamFromPool(device_index);
    remember_stream(pooled_stream_obj);
    auto pooled_stream =
        reinterpret_cast<void*>(pooled_stream_obj.stream(false));
    {
        // Deletes the stream guard if a check in this scope throws, so a
        // failure does not leave the pooled stream selected.
        struct StreamGuardCleanup {
            NPUStreamGuardHandle handle = nullptr;
            ~StreamGuardCleanup()
            {
                if (handle != nullptr) {
                    // Best effort only: already unwinding from a failure.
                    static_cast<void>(aoti_torch_delete_npu_stream_guard(handle));
                }
            }
        } stream_cleanup;

        check_aoti_error(
            aoti_torch_create_npu_stream_guard(pooled_stream, device_index, &stream_cleanup.handle),
            "aoti_torch_create_npu_stream_guard");
        auto guarded_stream =
            reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));
        TORCH_CHECK(
            guarded_stream == pooled_stream,
            "NPU stream guard did not switch to the pooled stream");

        NPUStreamGuardHandle stream_guard = stream_cleanup.handle;
        stream_cleanup.handle = nullptr;
        check_aoti_error(
            aoti_torch_delete_npu_stream_guard(stream_guard),
            "aoti_torch_delete_npu_stream_guard");
    }
    auto restored_stream =
        reinterpret_cast<void*>(c10_npu::getCurrentNPUStream(device_index).stream(false));
    TORCH_CHECK(
        restored_stream == original_stream,
        "Deleting the NPU stream guard did not restore the original stream");

    // 5. Raw allocator. The 0x1 sentinel proves the zero-byte call actually
    // overwrote the output pointer rather than leaving it untouched.
    void* zero_alloc = reinterpret_cast<void*>(0x1);
    check_aoti_error(
        aoti_torch_npu_caching_allocator_raw_alloc(0, &zero_alloc),
        "aoti_torch_npu_caching_allocator_raw_alloc(0)");
    TORCH_CHECK(
        zero_alloc == nullptr,
        "Zero-byte NPU allocation should return nullptr");
    void* alloc_ptr = nullptr;
    check_aoti_error(
        aoti_torch_npu_caching_allocator_raw_alloc(64, &alloc_ptr),
        "aoti_torch_npu_caching_allocator_raw_alloc");
    TORCH_CHECK(alloc_ptr != nullptr, "NPU allocator returned nullptr");
    check_aoti_error(
        aoti_torch_npu_caching_allocator_raw_delete(alloc_ptr),
        "aoti_torch_npu_caching_allocator_raw_delete");

    // 6. Blob tensor aliasing the input's storage.
    auto contiguous = tensor.contiguous();
    auto sizes = to_vector(contiguous.sizes());
    auto strides = to_vector(contiguous.strides());
    AtenTensorHandle alias_handle = nullptr;
    check_aoti_error(
        aoti_torch_create_tensor_from_blob_npu(
            contiguous.data_ptr(),
            contiguous.dim(),
            sizes.data(),
            strides.data(),
            contiguous.storage_offset(),
            aoti_torch_dtype_float32(),
            aoti_torch_device_type_npu(),
            device_index,
            &alias_handle),
        "aoti_torch_create_tensor_from_blob_npu");
    // Take our own reference, then release the shim handle right away so a
    // failing property check below cannot leak it.
    at::Tensor alias =
        *torch::aot_inductor::tensor_handle_to_tensor_pointer(alias_handle);
    check_aoti_error(
        aoti_torch_delete_tensor_object(alias_handle),
        "aoti_torch_delete_tensor_object(alias_handle)");

    TORCH_CHECK(alias.device() == contiguous.device(), "Alias tensor device mismatch");
    TORCH_CHECK(alias.scalar_type() == contiguous.scalar_type(), "Alias tensor dtype mismatch");
    TORCH_CHECK(alias.data_ptr() == contiguous.data_ptr(), "Alias tensor data_ptr mismatch");
    TORCH_CHECK(
        to_vector(alias.sizes()) == sizes,
        "Alias tensor sizes mismatch");
    TORCH_CHECK(
        to_vector(alias.strides()) == strides,
        "Alias tensor strides mismatch");
    TORCH_CHECK(
        alias.storage_offset() == contiguous.storage_offset(),
        "Alias tensor storage_offset mismatch");
    TORCH_CHECK(alias.equal(contiguous), "Alias tensor value mismatch");

    auto result = at::add(alias, 1.0);
    return result;
}
}
namespace c10_npu {
// Resolves a raw aclrtStream to an NPUStream for this test extension.
//
// Lookup order:
//   1. the stream registry filled by remember_stream();
//   2. the current and default streams of `device_index`, or of every device
//      when `device_index` is -1. A match is added to the registry.
// Raises through TORCH_CHECK when `stream` is null or nothing matches.
NPUStream getNPUStreamFromManagedAclrtStream(aclrtStream stream, c10::DeviceIndex device_index)
{
    TORCH_CHECK(stream != nullptr, "stream is nullptr");

    // Registry lookup; the lock is scoped to the lambda so it is released
    // before remember_stream() below takes the same mutex.
    auto find_registered_stream = [stream]() -> std::optional<NPUStream> {
        std::lock_guard<std::mutex> lock(stream_registry_mutex());
        auto it = stream_registry().find(reinterpret_cast<void*>(stream));
        if (it == stream_registry().end()) {
            return std::nullopt;
        }
        return NPUStream(NPUStream::UNCHECKED, it->second);
    };
    if (auto registered = find_registered_stream()) {
        return *registered;
    }

    // Compares `stream` against the current and default streams of one device.
    auto find_known_stream = [stream](c10::DeviceIndex idx) -> std::optional<NPUStream> {
        auto current = getCurrentNPUStream(idx);
        if (current.stream(false) == stream) {
            remember_stream(current);
            return current;
        }
        auto default_stream = getDefaultNPUStream(idx);
        if (default_stream.stream(false) == stream) {
            remember_stream(default_stream);
            return default_stream;
        }
        return std::nullopt;
    };
    if (device_index != -1) {
        if (auto known = find_known_stream(device_index)) {
            return *known;
        }
    } else {
        const auto device_count = c10_npu::device_count();
        for (c10::DeviceIndex idx = 0; idx < device_count; ++idx) {
            if (auto known = find_known_stream(idx)) {
                return *known;
            }
        }
    }

    // c10::DeviceIndex is an 8-bit integer; streamed as-is it would be
    // formatted as a character, so widen it to print the numeric index.
    TORCH_CHECK(
        false,
        "The aclrtStream is not managed by the shim test registry on device ",
        static_cast<int>(device_index));
}
}
// Python bindings: each helper is exported under its C++ name without the
// "_impl" suffix. Registration order matches the definition order above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("get_npu_raw_stream", &get_npu_raw_stream_impl)
        // Blob tensor construction.
        .def("make_zero_size_blob_tensor", &make_zero_size_blob_tensor_impl)
        .def("make_zero_size_cpu_blob_tensor", &make_zero_size_cpu_blob_tensor_impl)
        .def("check_mkldnn_blob_tensor_v2_rejected", &check_mkldnn_blob_tensor_v2_rejected_impl)
        .def(
            "check_blob_tensor_v2_propagates_invalid_device_failure",
            &check_blob_tensor_v2_propagates_invalid_device_failure_impl)
        // Null / invalid input handling.
        .def("check_null_delete_paths", &check_null_delete_paths_impl)
        .def("check_invalid_stream_guard_path", &check_invalid_stream_guard_path_impl)
        .def("check_null_stream_guard_path", &check_null_stream_guard_path_impl)
        .def("check_invalid_device_guard_creation", &check_invalid_device_guard_creation_impl)
        .def("check_invalid_device_guard_set_index", &check_invalid_device_guard_set_index_impl)
        .def("check_invalid_device_current_stream", &check_invalid_device_current_stream_impl)
        // Null output-pointer handling.
        .def("check_null_stream_guard_output_handle", &check_null_stream_guard_output_handle_impl)
        .def("check_null_current_stream_output_handle", &check_null_current_stream_output_handle_impl)
        .def("check_null_guard_output_handle", &check_null_guard_output_handle_impl)
        .def("check_null_allocator_output_handle", &check_null_allocator_output_handle_impl)
        .def("check_null_blob_tensor_output_handle", &check_null_blob_tensor_output_handle_impl)
        .def("check_null_blob_tensor_v2_output_handle", &check_null_blob_tensor_v2_output_handle_impl)
        // Positive-path round trips.
        .def("check_current_device_stream_lookup", &check_current_device_stream_lookup_impl)
        .def("check_default_stream_guard_roundtrip", &check_default_stream_guard_roundtrip_impl)
        .def("run_npu_shim_checks", &run_npu_shim_checks_impl);
}