#include <Python.h>

#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <torch/library.h>

#include "randn_simt.h"

extern "C" {
// Module init hook that CPython's import machinery looks up (by the name
// PyInit_<module>) when the `_C` extension is imported. The module created
// here is empty: no docstring, no methods, and m_size == -1 (no per-interpreter
// state block). Presumably its only purpose is to make this shared library
// importable from Python, so that loading it runs the operator registration
// in this translation unit. NOTE(review): confirm no Python-level API is
// expected from `_C`.
PyObject* PyInit__C(void) {
  // Function-local static: CPython requires the PyModuleDef to outlive the
  // module object created from it.
  static PyModuleDef module_def = {
      PyModuleDef_HEAD_INIT,
      /*m_name=*/"_C",
      /*m_doc=*/nullptr,
      /*m_size=*/-1,
      /*m_methods=*/nullptr,
  };
  // PyModule_Create returns nullptr with a Python exception set on failure;
  // that result is handed straight back to the import machinery.
  PyObject* const created = PyModule_Create(&module_def);
  return created;
}
}  // extern "C"

namespace simt_randn {

/// Adapter that has the C++ signature expected for the ATen `normal_` kernel
/// (self, mean, std, generator). It forwards every argument unchanged to
/// at::native::simt_randn::normal_privateuse1_, which is presumably declared
/// in randn_simt.h.
///
/// @param self       Mutable destination tensor, passed through by reference.
/// @param mean       Forwarded unchanged.
/// @param std        Forwarded unchanged. It is named after the ATen schema
///                   argument. The `std::move` below still resolves to the
///                   standard namespace, because the name before `::` is
///                   looked up only among namespaces and types, not variables.
/// @param generator  Optional RNG; ownership is moved into the callee.
/// @return The tensor reference returned by normal_privateuse1_.
at::Tensor& wrapper_normal_(
    at::Tensor& self,
    double mean,
    double std,
    std::optional<at::Generator> generator) {
  namespace backend = at::native::simt_randn;
  at::Tensor& filled =
      backend::normal_privateuse1_(self, mean, std, std::move(generator));
  return filled;
}

}  // namespace simt_randn

// Registers simt_randn::wrapper_normal_ as the PrivateUse1 kernel for
// `aten::normal_`. TORCH_LIBRARY_IMPL performs the registration from a static
// initializer when this shared library is loaded.
// TORCH_FN passes the kernel as a compile-time function pointer rather than a
// runtime one. The dispatcher's unboxed wrapper can then call (and potentially
// inline) the kernel directly instead of through an indirect call.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("normal_", TORCH_FN(simt_randn::wrapper_normal_));
}