import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
cpp_source = """
#include <torch/extension.h>
torch::Tensor ReLU_kernel(torch::Tensor x) {
if (!x.is_contiguous()) {
x = x.contiguous();
}
torch::ScalarType dtype = x.scalar_type();
bool need_convert = (dtype != torch::kFloat32 && dtype != torch::kInt32 && dtype != torch::kInt64);
torch::Tensor input = need_convert ? x.to(torch::kFloat32) : x;
torch::Tensor output = torch::zeros_like(input);
if (input.scalar_type() == torch::kFloat32) {
auto x_ptr = input.data_ptr<float>();
auto out_ptr = output.data_ptr<float>();
int64_t numel = input.numel();
for (int64_t i = 0; i < numel; ++i) {
out_ptr[i] = std::max(0.0f, x_ptr[i]);
}
} else if (input.scalar_type() == torch::kInt32) {
auto x_ptr = input.data_ptr<int32_t>();
auto out_ptr = output.data_ptr<int32_t>();
int64_t numel = input.numel();
for (int64_t i = 0; i < numel; ++i) {
out_ptr[i] = std::max(0, x_ptr[i]);
}
} else if (input.scalar_type() == torch::kInt64) {
auto x_ptr = input.data_ptr<int64_t>();
auto out_ptr = output.data_ptr<int64_t>();
int64_t numel = input.numel();
for (int64_t i = 0; i < numel; ++i) {
out_ptr[i] = std::max(static_cast<int64_t>(0), x_ptr[i]);
}
}
if (need_convert) {
output = output.to(dtype);
}
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ReLU_kernel", &ReLU_kernel, "ReLU C++ kernel");
}
"""
ReLU_module = load_inline(
name="ReLU",
cpp_sources=cpp_source,
extra_cflags=["-O3"],
verbose=True
)
class ModelNew(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.device.type != "cpu":
x = x.cpu()
return ReLU_module.ReLU_kernel(x)