#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

// Out-of-place scatter_add dispatched to the aclnnScatterAdd op-API kernel.
//
// DO_COMPATIBILITY receives acl_op::scatter_add(...) as its fallback expression.
// NOTE(review): presumably the macro returns that call's result when the
// aclnnScatterAdd entry point cannot be used — confirm against its definition.
// On the aclnn path the kernel operates on a contiguous clone of `self`, so the
// caller's tensor is not modified; the clone is what gets returned.
at::Tensor scatter_add(
    const at::Tensor& self,
    int64_t dim,
    const at::Tensor& index,
    const at::Tensor& src)
{
    DO_COMPATIBILITY(aclnnScatterAdd, acl_op::scatter_add(self, dim, index, src));

    // Fresh contiguous copy of `self`. It is passed to the kernel in both the
    // leading input slot and the trailing (output) slot, so the result lands in it.
    at::Tensor result = self.clone(at::MemoryFormat::Contiguous);
    // NOTE(review): presumably validates input/output memory overlap — confirm
    // against OpPreparation::CheckMemory.
    npu_preparation::CheckMemory({result, index, src}, {result});
    EXEC_NPU_CMD(aclnnScatterAdd, result, dim, index, src, result);
    return result;
}

// In-place scatter_add_: `self` is handed to aclnnScatterAdd as both the input
// and the output tensor, and the same reference is returned to the caller.
//
// DO_COMPATIBILITY receives acl_op::scatter_add_(...) as its fallback expression.
// NOTE(review): presumably used when aclnnScatterAdd is unavailable — confirm
// against the macro definition.
at::Tensor& scatter_add_(
    at::Tensor& self,
    int64_t dim,
    const at::Tensor& index,
    const at::Tensor& src)
{
    DO_COMPATIBILITY(aclnnScatterAdd, acl_op::scatter_add_(self, dim, index, src));

    // `self` appears in both lists because it is read and written by the kernel.
    npu_preparation::CheckMemory({self, index, src}, {self});
    EXEC_NPU_CMD(aclnnScatterAdd, self, dim, index, src, self);
    return self;
}

// Named-dimension overload: converts `dim` to a positional index with
// dimname_to_position and forwards to the int64_t overload defined above.
// The compatibility check runs here first, so the fallback uses the
// Dimname-taking acl_op::scatter_add.
at::Tensor scatter_add(
    const at::Tensor& self,
    at::Dimname dim,
    const at::Tensor& index,
    const at::Tensor& src)
{
    DO_COMPATIBILITY(aclnnScatterAdd, acl_op::scatter_add(self, dim, index, src));

    const int64_t position = dimname_to_position(self, dim);
    return op_api::scatter_add(self, position, index, src);
}
}  // namespace op_api