#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
// Short aliases for the NPU helper classes used by the kernels in this file:
// npu_preparation supplies CheckMemory, npu_utils supplies the
// check_match / format_contiguous / format_fresh_view helpers.
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
namespace {
// Launches the "ScatterElements" operator with reduction mode "add".
//
// Inputs are passed in the order (self, index, src); `dim` is forwarded as the
// "axis" attribute and the operator output is bound to `result`.
// This helper performs no validation of its arguments itself; callers are
// expected to have prepared the tensors beforehand.
//
// Returns `result` so the call can be used in an expression.
at::Tensor& scatter_add_out_npu_nocheck(
    at::Tensor& result,
    const at::Tensor& self,
    int64_t dim,
    const at::Tensor& index,
    const at::Tensor& src) {
  const std::string reduce_mode{"add"};

  at_npu::native::OpCommand scatter_cmd;
  scatter_cmd.Name("ScatterElements");

  // Operand order matters: data, indices, updates.
  scatter_cmd.Input(self);
  scatter_cmd.Input(index);
  scatter_cmd.Input(src);
  scatter_cmd.Output(result);

  scatter_cmd.Attr("axis", dim);
  scatter_cmd.Attr("reduction", reduce_mode);
  scatter_cmd.Run();

  return result;
}
}
// Out-of-place scatter-add along an integer dimension.
//
// `self` is left untouched: a contiguous clone of it is created and the
// in-place variant (scatter_add_) is applied to that clone, which is returned.
at::Tensor scatter_add(
    const at::Tensor& self,
    int64_t dim,
    const at::Tensor& index,
    const at::Tensor& src) {
  at::Tensor output = self.clone(at::MemoryFormat::Contiguous);
  output.scatter_add_(dim, index, src);
  return output;
}
// In-place scatter-add: accumulates `src` into `self` at the positions given
// by `index` along dimension `dim`, via the ScatterElements helper.
//
// CheckMemory is invoked first with {self, index, src} as inputs and {self}
// as output. If `self` does not pass npu_utils::check_match (presumably: its
// layout/format cannot be consumed directly by the operator -- confirm against
// NpuUtils), the operator runs on a contiguous copy and the result is then
// written back into `self` through format_fresh_view.
//
// Returns `self`.
at::Tensor& scatter_add_(
    at::Tensor& self,
    int64_t dim,
    const at::Tensor& index,
    const at::Tensor& src) {
  npu_preparation::CheckMemory({self, index, src}, {self});
  if (!npu_utils::check_match(&self)) {
    at::Tensor contiguous_self = npu_utils::format_contiguous(self);
    // Use the contiguous copy as both input and output, mirroring the fast
    // path below. Feeding the original non-matching `self` as the input would
    // hand the operator a tensor that failed check_match even though an
    // already-prepared copy of the same data is available.
    scatter_add_out_npu_nocheck(contiguous_self, contiguous_self, dim, index, src);
    npu_utils::format_fresh_view(self, contiguous_self);
  } else {
    // `self` can be used directly: operate fully in place.
    scatter_add_out_npu_nocheck(self, self, dim, index, src);
  }
  return self;
}
// Out-of-place scatter-add addressed by a named dimension.
//
// Resolves `dim` to its positional index on `self` with dimname_to_position
// and delegates to the integer-dimension overload of acl_op::scatter_add.
at::Tensor scatter_add(
    const at::Tensor& self,
    at::Dimname dim,
    const at::Tensor& index,
    const at::Tensor& src) {
  const int64_t dim_index = dimname_to_position(self, dim);
  return acl_op::scatter_add(self, dim_index, index, src);
}
}