/**
 * @file add_custom.cpp
 *
 * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 */
#include <torch/library.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/extension.h>
#include "npu_cpp_extension.h"
// Short names for the autograd types referenced in this file.
using torch::autograd::Function;
using torch::autograd::AutogradContext;
// Tensor list type used for backward()'s gradient argument and return value.
using variable_list = std::vector<at::Tensor>;
/**
 * @brief Backend kernel for myops::add_custom: computes self + other via aclnnAdd.
 *
 * The output is allocated with at::empty_like(self) and filled by launching
 * aclnnAdd through EXEC_NPU_CMD_EXT with a scale factor (alpha) of 1.
 *
 * @param self  Left operand; the output is allocated to match this tensor.
 * @param other Right operand.
 * @return Newly allocated tensor holding the aclnnAdd result.
 *
 * @note The output is sized from `self` only, so a call in which `other`
 *       would broadcast to a shape larger than `self` is not handled here.
 */
at::Tensor add_custom_impl_npu(const at::Tensor& self, const at::Tensor& other)
{
    at::Tensor result = at::empty_like(self);
    // Use an integral 1 rather than 1.0: an integer scalar is castable to every
    // tensor dtype, while a floating-point alpha is rejected by PyTorch's own add
    // for integral inputs (and presumably by aclnnAdd's alpha castability check).
    at::Scalar alpha = 1;
    EXEC_NPU_CMD_EXT(aclnnAdd, self, other, alpha, result);
    return result;
}
/**
 * @brief Backend kernel for myops::add_custom_backward.
 *
 * Passes the incoming gradient through unchanged for both operands; no new
 * tensor is allocated, both tuple entries are handles to `grad`.
 *
 * @param grad Gradient handed in by AddCustomFunction::backward.
 * @return Tuple (gradient for self, gradient for other).
 */
std::tuple<at::Tensor, at::Tensor> add_custom_backward_impl_npu(const at::Tensor& grad)
{
    return std::make_tuple(grad, grad);
}
/**
 * @brief Meta kernel for myops::add_custom.
 *
 * Produces only an output placeholder shaped like `self`; no addition is
 * performed and `other` is not inspected.
 *
 * @param self  Left operand; the output is allocated to match this tensor.
 * @param other Right operand (unused).
 * @return Tensor created by at::empty_like(self).
 */
at::Tensor add_custom_impl_meta(const at::Tensor& self, const at::Tensor& other)
{
    at::Tensor output = at::empty_like(self);
    return output;
}
std::tuple<at::Tensor, at::Tensor> add_custom_backward_impl_meta(const at::Tensor& self)
{
auto result = at::empty_like(self);
return std::make_tuple(result, result);
}
/**
 * @brief Custom autograd Function tying myops::add_custom to its backward operator.
 *
 * Both passes look their operator up in the dispatcher by schema name and call
 * it, so the kernel that actually runs is whichever one is registered for the
 * tensors' dispatch key.
 */
class AddCustomFunction : public torch::autograd::Function<AddCustomFunction> {
public:
    /**
     * @brief Forward pass: redispatches to myops::add_custom.
     *
     * @param ctx   Autograd context (unused; nothing is saved for backward).
     * @param self  Left operand.
     * @param other Right operand.
     * @return Result of the dispatched myops::add_custom call.
     */
    static at::Tensor forward(AutogradContext *ctx, at::Tensor self, at::Tensor other)
    {
        // Keep the guard alive for the whole call so the dispatch below skips
        // the autograd layer instead of re-entering this wrapper.
        at::AutoDispatchBelowADInplaceOrView guard;
        // The operator handle is resolved once and cached for later calls.
        static auto forward_op = torch::Dispatcher::singleton()
                                     .findSchemaOrThrow("myops::add_custom", "")
                                     .typed<decltype(add_custom_impl_npu)>();
        return forward_op.call(self, other);
    }

    /**
     * @brief Backward pass: redispatches to myops::add_custom_backward.
     *
     * @param ctx          Autograd context (unused).
     * @param grad_outputs Incoming gradients; only element 0 is read.
     * @return Two gradients, for `self` and `other` in that order.
     */
    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs)
    {
        const at::Tensor& incoming_grad = grad_outputs[0];
        static auto backward_op = torch::Dispatcher::singleton()
                                      .findSchemaOrThrow("myops::add_custom_backward", "")
                                      .typed<decltype(add_custom_backward_impl_npu)>();
        auto grad_pair = backward_op.call(incoming_grad);

        variable_list grad_inputs;
        grad_inputs.reserve(2);
        grad_inputs.push_back(std::get<0>(grad_pair));
        grad_inputs.push_back(std::get<1>(grad_pair));
        return grad_inputs;
    }
};
/**
 * @brief Autograd-key kernel for myops::add_custom.
 *
 * Routes the call through AddCustomFunction::apply so the custom backward is
 * attached to the result.
 *
 * @param self  Left operand.
 * @param other Right operand.
 * @return Tensor returned by AddCustomFunction::apply.
 */
at::Tensor add_custom_autograd(const at::Tensor& self, const at::Tensor& other)
{
    auto output = AddCustomFunction::apply(self, other);
    return output;
}
// Backend kernels: bind both myops operators to their *_npu implementations
// under the PrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(myops, PrivateUse1, m) {
m.impl("add_custom", &add_custom_impl_npu);
m.impl("add_custom_backward", &add_custom_backward_impl_npu);
}
// Autograd layer: only add_custom gets an autograd kernel here; it wraps the
// call in AddCustomFunction. add_custom_backward has no autograd registration
// in this section.
TORCH_LIBRARY_IMPL(myops, AutogradPrivateUse1, m) {
m.impl("add_custom", &add_custom_autograd);
}
// Meta kernels: bind both myops operators to their *_meta implementations,
// which only allocate output placeholders with at::empty_like.
TORCH_LIBRARY_IMPL(myops, Meta, m) {
m.impl("add_custom", &add_custom_impl_meta);
m.impl("add_custom_backward", &add_custom_backward_impl_meta);
}