#include <Python.h>
#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <torch/library.h>
#include "randn_simt.h"
extern "C" {
PyObject* PyInit__C(void) {
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C",
nullptr,
-1,
nullptr,
};
return PyModule_Create(&module_def);
}
}
namespace simt_randn {

/// @brief Dispatcher-facing kernel for the in-place `normal_` op on PrivateUse1.
///
/// Forwards every argument unchanged to
/// `at::native::simt_randn::normal_privateuse1_`. The optional generator is
/// handed over by move. This function adds no logic of its own.
///
/// @param self       Tensor to be filled in place.
/// @param mean       Mean forwarded to the backend kernel.
/// @param std        Standard deviation forwarded to the backend kernel. The
///                   name presumably mirrors the operator's argument name.
///                   Qualified names such as `std::move` still resolve to the
///                   namespace, because lookup before `::` ignores variables.
/// @param generator  Optional RNG generator, moved into the backend kernel.
/// @return The tensor reference returned by `normal_privateuse1_`.
at::Tensor& wrapper_normal_(
    at::Tensor& self,
    double mean,
    double std,
    std::optional<at::Generator> generator) {
  namespace backend = ::at::native::simt_randn;
  auto& filled =
      backend::normal_privateuse1_(self, mean, std, std::move(generator));
  return filled;
}

}  // namespace simt_randn
// Registers this extension's kernel for the in-place `aten::normal_` operator
// under the PrivateUse1 dispatch key.
//
// TORCH_FN turns the kernel into a compile-time function pointer. The
// dispatcher's unboxed wrapper then calls it directly (and may inline it)
// instead of through a runtime function-pointer indirection. This is the form
// PyTorch's own generated Register*.cpp files use.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("normal_", TORCH_FN(simt_randn::wrapper_normal_));
}