* Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/Generator.h>
#include <ATen/Tensor.h>
#include <ATen/core/op_registration/op_registration.h>
#include "kernels_commons.h"
#include "cpu/kernels.h"
#ifdef WITH_CUDA
#include <c10/cuda/CUDAGuard.h>
#include "cuda/kernels.cuh"
#endif
using namespace at;
using namespace torch::csprng;
// Error text raised by the CSPRNG op dispatchers when the tensor's device type
// has no kernel in this build (CUDA kernels exist only when built WITH_CUDA).
static constexpr const char* GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE = "generator does not support tensor device type";
// Error text raised by encrypt_pybind / decrypt_pybind for unsupported input devices.
static constexpr const char* TENSOR_DEVICE_TYPE_IS_NOT_SUPPORTED = "tensor device type is not supported";
// Backend for aten::random_: forwards to the CPU or (WITH_CUDA builds) CUDA
// CSPRNG kernel selected by `self`'s device type. Any other device type fails
// with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& random_(Tensor& self, c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::random_(self, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::random_(self, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::random_.from: forwards `from`/`to` to the CPU or
// (WITH_CUDA builds) CUDA CSPRNG kernel selected by `self`'s device type.
// Any other device type fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& random_from_to(Tensor& self, int64_t from, optional<int64_t> to,
                       c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::random_from_to(self, from, to, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::random_from_to(self, from, to, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::random_.to: forwards `to` to the CPU or (WITH_CUDA builds)
// CUDA CSPRNG kernel selected by `self`'s device type. Any other device type
// fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& random_to(Tensor& self, int64_t to,
                  c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::random_to(self, to, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::random_to(self, to, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::uniform_: forwards `from`/`to` to the CPU or (WITH_CUDA
// builds) CUDA CSPRNG kernel selected by `self`'s device type. Any other
// device type fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& uniform_(Tensor& self, double from, double to, c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::uniform_(self, from, to, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::uniform_(self, from, to, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::normal_: forwards `mean`/`std` to the CPU or (WITH_CUDA
// builds) CUDA CSPRNG kernel selected by `self`'s device type. Any other
// device type fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::normal_(self, mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::normal_(self, mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::normal.Tensor_float_out. The backend is chosen from the
// device of `output` (not `mean`); the backend kernels take `output` first.
// Unsupported device types fail with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& normal_Tensor_float_out(const Tensor& mean, double std, c10::optional<Generator> gen, Tensor& output) {
  switch (output.device().type()) {
    case DeviceType::CPU:
      return cpu::normal_Tensor_float_out(output, mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::normal_Tensor_float_out(output, mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::normal.float_Tensor_out. The backend is chosen from the
// device of `output` (not `std`); the backend kernels take `output` first.
// Unsupported device types fail with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& normal_float_Tensor_out(double mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
  switch (output.device().type()) {
    case DeviceType::CPU:
      return cpu::normal_float_Tensor_out(output, mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::normal_float_Tensor_out(output, mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::normal.Tensor_Tensor_out. The backend is chosen from the
// device of `output`; the backend kernels take `output` first.
// Unsupported device types fail with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& normal_Tensor_Tensor_out(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
  switch (output.device().type()) {
    case DeviceType::CPU:
      return cpu::normal_Tensor_Tensor_out(output, mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::normal_Tensor_Tensor_out(output, mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::normal.Tensor_float. The backend is chosen from the
// device of `mean`. Unsupported device types fail with
// GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor normal_Tensor_float(const Tensor& mean, double std, c10::optional<Generator> gen) {
  switch (mean.device().type()) {
    case DeviceType::CPU:
      return cpu::normal_Tensor_float(mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::normal_Tensor_float(mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::normal.float_Tensor. The backend is chosen from the
// device of `std`. Unsupported device types fail with
// GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor normal_float_Tensor(double mean, const Tensor& std, c10::optional<Generator> gen) {
  switch (std.device().type()) {
    case DeviceType::CPU:
      return cpu::normal_float_Tensor(mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::normal_float_Tensor(mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::normal.Tensor_Tensor. The backend is chosen from the
// device of `mean` only. Unsupported device types fail with
// GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor normal_Tensor_Tensor(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
  switch (mean.device().type()) {
    case DeviceType::CPU:
      return cpu::normal_Tensor_Tensor(mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::normal_Tensor_Tensor(mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::cauchy_: forwards `median`/`sigma` to the CPU or
// (WITH_CUDA builds) CUDA CSPRNG kernel selected by `self`'s device type.
// Any other device type fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& cauchy_(Tensor& self, double median, double sigma, c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::cauchy_(self, median, sigma, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::cauchy_(self, median, sigma, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::log_normal_: forwards `mean`/`std` to the CPU or
// (WITH_CUDA builds) CUDA CSPRNG kernel selected by `self`'s device type.
// Any other device type fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& log_normal_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::log_normal_(self, mean, std, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::log_normal_(self, mean, std, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::geometric_: forwards `p` to the CPU or (WITH_CUDA builds)
// CUDA CSPRNG kernel selected by `self`'s device type. Any other device type
// fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& geometric_(Tensor& self, double p, c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::geometric_(self, p, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::geometric_(self, p, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
// Backend for aten::exponential_: forwards `lambda` to the CPU or (WITH_CUDA
// builds) CUDA CSPRNG kernel selected by `self`'s device type. Any other
// device type fails with GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE.
Tensor& exponential_(Tensor& self, double lambda, c10::optional<Generator> gen) {
  switch (self.device().type()) {
    case DeviceType::CPU:
      return cpu::exponential_(self, lambda, gen);
#ifdef WITH_CUDA
    case DeviceType::CUDA:
      return torch::csprng::cuda::exponential_(self, lambda, gen);
#endif
    default:
      TORCH_CHECK(false, GENERATOR_DOES_NOT_SUPPORT_TENSOR_DEVICE_TYPE);
  }
}
namespace {

// Checks that n fits `tensor`'s dtype and, for Half/Float/Double, that every
// value in [0, n) is an exactly representable integer, so randperm can store
// the permutation without rounding. Fails via TORCH_CHECK otherwise.
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
  // Relies on at::scalar_tensor rejecting values that overflow the dtype.
  // NOTE(review): confirm; .defined() alone does not check the range.
  TORCH_CHECK(at::scalar_tensor(n, tensor.options()).defined(),
              "n is too large for result tensor type: '", tensor.toString(), "'");
  // Half/Float/Double carry 11/24/53 significand bits, so integers up to 2^bits
  // are exact. The largest stored value is n - 1, hence the "+ 1".
  switch (tensor.scalar_type()) {
    case at::ScalarType::Half:
      TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
      break;
    case at::ScalarType::Float:
      TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
      break;
    case at::ScalarType::Double:
      TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
      break;
    default:
      break;
  }
}

// Sizes at or above this threshold draw shuffle indices with random64().
// random() is assumed to yield 32 bits (confirm against CSPRNGGeneratorImpl).
// With a 32-bit draw, the modulo bias of `draw % remaining` grows with n, and
// once n exceeds 2^32 some indices can never be selected. Smaller sizes keep
// the 32-bit draw so seeded generators reproduce earlier permutations.
constexpr int64_t kRandpermMin64BitDrawSize = (int64_t(1) << 32) / 20;

// Fills `result` (resized to [n]) with a random permutation of 0..n-1 using an
// in-place Fisher-Yates shuffle driven by `generator`.
// The caller is responsible for holding the generator's mutex.
template <typename scalar_t, typename RNG>
void randperm(Tensor& result, int64_t n, c10::optional<at::Generator> generator) {
  auto gen = at::check_generator<RNG>(generator);
  // Resize BEFORE taking the data pointer: resize_ may reallocate storage,
  // which would leave a previously obtained pointer dangling.
  result.resize_({n});
  scalar_t* r__data = result.data_ptr<scalar_t>();
  const int64_t r__stride_0 = result.stride(0);
  // Write the identity permutation. Element writes are independent, so this runs in parallel.
  at::parallel_for(0, n, internal::GRAIN_SIZE,
                   [&r__data, &r__stride_0](int64_t p_begin, int64_t p_end) {
                     for (int64_t i = p_begin; i < p_end; i++) {
                       r__data[i * r__stride_0] = static_cast<scalar_t>(i);
                     }
                   });
  const bool use_64bit_draw = n >= kRandpermMin64BitDrawSize;
  // Sequential Fisher-Yates: swap position i with a random position in [i, n).
  for (int64_t i = 0; i < n - 1; i++) {
    const int64_t remaining = n - i;
    const int64_t z = use_64bit_draw
        ? static_cast<int64_t>(gen->random64() % static_cast<uint64_t>(remaining))
        : static_cast<int64_t>(gen->random() % remaining);
    scalar_t sav = r__data[i * r__stride_0];
    r__data[i * r__stride_0] = r__data[(z + i) * r__stride_0];
    r__data[(z + i) * r__stride_0] = sav;
  }
}

}  // namespace
// Backend for aten::randperm.generator_out: fills `result` with a random
// permutation of 0..n-1 drawn from the CSPRNG `generator` and returns it.
// CUDA outputs are produced on the CPU with the same generator and then
// copied to the device.
Tensor& randperm_generator_out(int64_t n, c10::optional<Generator> generator, Tensor& result) {
  TORCH_CHECK(n >= 0, "n must be non-negative, got ", n);
  check_supported_max_int_with_precision(n, result);
  if (result.device().type() == at::kCUDA) {
    auto result_cpu = at::empty({n}, result.options().device(kCPU));
    randperm_generator_out(n, generator, result_cpu);
    result.resize_({n});
    return result.copy_(result_cpu);
  }
  // The generator's mutex is dereferenced below; fail cleanly instead of
  // dereferencing an empty optional.
  TORCH_CHECK(generator.has_value(), "randperm: expected a CSPRNG generator");
  result.resize_({n});
  // Hold the generator lock for the whole shuffle so its state is not
  // advanced concurrently.
  std::lock_guard<std::mutex> lock(generator->mutex());
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.scalar_type(), "randperm", [&]() -> void {
    randperm<scalar_t, CSPRNGGeneratorImpl>(result, n, generator);
  });
  return result;
}
// Forwards encryption of `input` to the CPU or (WITH_CUDA builds) CUDA
// `encrypt` kernel selected by `input`'s device type. `output`, `key`,
// `cipher` and `mode` are passed through unchanged. The result is presumably
// written to `output` (confirm against the kernels).
// On CUDA, the guard makes `input`'s GPU current for the duration of the call.
// Other device types fail with TENSOR_DEVICE_TYPE_IS_NOT_SUPPORTED.
Tensor encrypt_pybind(Tensor input, Tensor output, Tensor key, const std::string& cipher, const std::string& mode) {
  switch (input.device().type()) {
    case DeviceType::CPU:
      return cpu::encrypt(input, output, key, cipher, mode);
#ifdef WITH_CUDA
    case DeviceType::CUDA: {
      c10::cuda::CUDAGuard device_guard(input.device());
      return torch::csprng::cuda::encrypt(input, output, key, cipher, mode);
    }
#endif
    default:
      TORCH_CHECK(false, TENSOR_DEVICE_TYPE_IS_NOT_SUPPORTED);
  }
}
// Forwards decryption of `input` to the CPU or (WITH_CUDA builds) CUDA
// `decrypt` kernel selected by `input`'s device type. `output`, `key`,
// `cipher` and `mode` are passed through unchanged. The result is presumably
// written to `output` (confirm against the kernels).
// Other device types fail with TENSOR_DEVICE_TYPE_IS_NOT_SUPPORTED.
Tensor decrypt_pybind(Tensor input, Tensor output, Tensor key, const std::string& cipher, const std::string& mode) {
  if (input.device().type() == DeviceType::CPU) {
    return cpu::decrypt(input, output, key, cipher, mode);
#ifdef WITH_CUDA
  } else if (input.device().type() == DeviceType::CUDA) {
    // Make `input`'s GPU current, as encrypt_pybind does. Without this the
    // kernel may be launched on whichever device happens to be current.
    c10::cuda::CUDAGuard device_guard(input.device());
    return torch::csprng::cuda::decrypt(input, output, key, cipher, mode);
#endif
  } else {
    TORCH_CHECK(false, TENSOR_DEVICE_TYPE_IS_NOT_SUPPORTED);
  }
}
// Creates a CSPRNGGeneratorImpl-backed Generator. If `token` is given, it is
// forwarded to the generator's constructor. Otherwise the generator is
// constructed with `true` (presumably selecting a random-device source;
// confirm against CSPRNGGeneratorImpl).
Generator create_random_device_generator(c10::optional<std::string> token = c10::nullopt) {
  return token.has_value()
      ? make_generator<CSPRNGGeneratorImpl>(*token)
      : make_generator<CSPRNGGeneratorImpl>(true);
}
// Creates a CSPRNGGeneratorImpl-backed Generator. If `seed` is given, it is
// forwarded to the generator's constructor. Otherwise the generator is
// constructed with `false` (presumably selecting the mt19937 engine with a
// default seed; confirm against CSPRNGGeneratorImpl).
Generator create_mt19937_generator(c10::optional<uint64_t> seed = c10::nullopt) {
  return seed.has_value()
      ? make_generator<CSPRNGGeneratorImpl>(*seed)
      : make_generator<CSPRNGGeneratorImpl>(false);
}
// Reports whether this extension was compiled with CUDA kernels (WITH_CUDA).
bool supports_cuda() {
#ifdef WITH_CUDA
  constexpr bool kCompiledWithCuda = true;
#else
  constexpr bool kCompiledWithCuda = false;
#endif
  return kCompiledWithCuda;
}
// AES-based pseudo-random generator exposed to TorchScript and Python.
// set_seeds() must be called before bit_random()/random(), because it
// derives all of the private state below from the seed tensor.
class AES_PRG : public torch::CustomClassHolder {
private:
  // Key tensor passed to the "aes128" encryption. set_seeds() creates it and
  // bit_random() increments it after every call.
  Tensor aes_key;
  // Seeds arranged so the last dimension holds one 128-bit block
  // (each_gen_num elements).
  Tensor prg_seeds;
  // Bits per seed element (element_size * 8). Zero until set_seeds() runs,
  // so it is never read uninitialized.
  int64_t BIT_LEN = 0;
  // Seed elements per 128-bit block (128 / BIT_LEN). Zero until set_seeds().
  int64_t each_gen_num = 0;
public:
  // Number of 128-bit seed blocks, i.e. rows of generated output.
  int64_t parallel_num = 0;
  // Device of the seed tensor. Defaults to CPU until set_seeds() runs.
  torch::Device device = torch::Device(torch::kCPU);
  AES_PRG();
  void set_seeds(Tensor seeds);
  torch::Tensor bit_random(int64_t bits);
  torch::Tensor random(int64_t length);
};
// Default constructor. Members keep their in-class/default values until set_seeds().
AES_PRG::AES_PRG() = default;
// Installs `seeds` as the generator state. Derives the bit width per element,
// the elements per 128-bit block, and the number of blocks. Records the seeds'
// device and (re)creates the key tensor {1, 2} as int64 on that device.
void AES_PRG::set_seeds(Tensor seeds){
  BIT_LEN = seeds.element_size() * 8;
  each_gen_num = 128 / BIT_LEN;
  parallel_num = seeds.numel() / each_gen_num;
  // Keep the caller's layout when its last dim is already one 128-bit block;
  // otherwise regroup the flat seeds into rows of one block each.
  const bool last_dim_is_block = seeds.size(-1) == each_gen_num;
  prg_seeds = last_dim_is_block ? seeds : seeds.view({-1, each_gen_num});
  device = seeds.device();
  // NOTE(review): the key starts from a fixed public value. Confirm that the
  // "custom" encryption mode does not rely on the key being secret.
  aes_key = tensor({1, 2}, seeds.options().dtype(torch::kInt64));
}
// Produces at least `bits` pseudo-random bits per seed block by encrypting the
// (tiled) seeds with "aes128" in "custom" mode.
// Returns a tensor shaped like prg_seeds, with the last dim set to
// ceil(bits / BIT_LEN) elements of the seed dtype.
// Increments the key after every call.
torch::Tensor AES_PRG::bit_random(int64_t bits)
{
  // Before set_seeds() there are no seeds and BIT_LEN is not meaningful
  // (the divisions below would be invalid).
  TORCH_CHECK(prg_seeds.defined(), "AES_PRG: set_seeds must be called before generating random values");
  TORCH_CHECK(bits >= 0, "AES_PRG: bits must be non-negative, got ", bits);
  std::vector<int64_t> seed_sizes = prg_seeds.sizes().vec();
  // Seed-dtype elements needed to hold `bits` bits (rounded up).
  const int64_t desired_num = (bits + BIT_LEN - 1) / BIT_LEN;
  // 128-bit blocks needed per seed row (rounded up).
  const int64_t desired_128_block = (bits + 127) / 128;
  Tensor out_tensor = empty({parallel_num, each_gen_num * desired_128_block}, prg_seeds.options());
  // Tile each seed block along the last dim, once per output block.
  std::vector<int64_t> repeat_pattern(seed_sizes.size(), 1);
  repeat_pattern.back() = desired_128_block;
  Tensor expanded_seeds = prg_seeds.repeat(repeat_pattern);
  encrypt_pybind(expanded_seeds, out_tensor, aes_key, "aes128", "custom");
  // Mutate the key after each call (presumably so successive calls yield
  // different output).
  aes_key += 1;
  // Keep only the requested number of elements, in the seeds' leading shape.
  seed_sizes.back() = desired_num;
  return out_tensor.slice(-1, 0, desired_num).view(seed_sizes);
}
// Returns `length` elements of the seed dtype per seed block, by requesting
// length * BIT_LEN bits from bit_random().
torch::Tensor AES_PRG::random(int64_t length)
{
  const int64_t total_bits = length * BIT_LEN;
  return bit_random(total_bits);
}
// Registers this file's functions as the implementations of the listed aten
// operators for the CustomRNGKeyId dispatch key. Presumably this key is
// selected when the op's Generator is a CSPRNG generator (confirm against
// CSPRNGGeneratorImpl's dispatch key set).
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
// In-place random integer fills.
m.impl("random_.from", random_from_to);
m.impl("random_.to", random_to);
m.impl("random_", random_);
// Continuous distributions (in-place and out/functional normal variants).
m.impl("uniform_", uniform_);
m.impl("normal_", normal_);
m.impl("normal.Tensor_float_out", normal_Tensor_float_out);
m.impl("normal.float_Tensor_out", normal_float_Tensor_out);
m.impl("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out);
m.impl("normal.Tensor_float", normal_Tensor_float);
m.impl("normal.float_Tensor", normal_float_Tensor);
m.impl("normal.Tensor_Tensor", normal_Tensor_Tensor);
m.impl("cauchy_", cauchy_);
m.impl("log_normal_", log_normal_);
m.impl("geometric_", geometric_);
m.impl("exponential_", exponential_);
// Permutation with an explicit generator and output tensor.
m.impl("randperm.generator_out", randperm_generator_out);
}
// Exposes AES_PRG as the custom class "AES_PRG" in the "csprng_aes" library
// (presumably reachable from TorchScript as torch.classes.csprng_aes.AES_PRG).
// Unlike the pybind11 binding, parallel_num and device are read-write here.
TORCH_LIBRARY(csprng_aes, m) {
m.class_<AES_PRG>("AES_PRG")
.def(torch::init<>())
.def("set_seeds", &AES_PRG::set_seeds)
.def("bit_random", &AES_PRG::bit_random)
.def("random", &AES_PRG::random)
.def_readwrite("parallel_num", &AES_PRG::parallel_num)
.def_readwrite("device", &AES_PRG::device);
}
// Python bindings for the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("supports_cuda", &supports_cuda);
// The nullptr defaults are presumably exposed to Python as None, mapping to
// c10::nullopt for the optional token/seed parameters.
m.def("create_random_device_generator", &create_random_device_generator, py::arg("token") = nullptr);
m.def("create_mt19937_generator", &create_mt19937_generator, py::arg("seed") = nullptr);
m.def("encrypt", &encrypt_pybind);
m.def("decrypt", &decrypt_pybind);
// AES_PRG is exposed to Python as "PRG", held by std::shared_ptr.
// device and parallel_num are read-only here, unlike the TorchScript binding.
py::class_<AES_PRG, std::shared_ptr<AES_PRG>>(m, "PRG")
.def(py::init<>())
.def("set_seeds", &AES_PRG::set_seeds)
.def("bit_random", &AES_PRG::bit_random)
.def("random", &AES_PRG::random)
.def_readonly("device", &AES_PRG::device)
.def_readonly("parallel_num", &AES_PRG::parallel_num);
}