#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
at::Tensor& index_add_out(
const at::Tensor& self,
int64_t dim,
const at::Tensor& index,
const at::Tensor& source,
const at::Scalar& alpha,
at::Tensor& result) {
DO_COMPATIBILITY(aclnnIndexAdd, acl_op::index_add_out(self, dim, index, source, alpha, result));
auto self_sizes = self.sizes().vec();
auto source_sizes = source.sizes().vec();
if (source.dim() != 0 && self.dim() != 0) {
auto wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
self_sizes.erase(self_sizes.begin() + wrapped_dim);
source_sizes.erase(source_sizes.begin() + wrapped_dim);
}
TORCH_CHECK(self_sizes == source_sizes,
"source tensor shape must match self tensor shape, excluding the specified dimension. Got self.shape = ",
self.sizes(),
" source.shape = ",
source.sizes(),
OPS_ERROR(ErrCode::PARAM));
at_npu::native::OpPreparation::check_tensor({self, index, source},
result,
result.scalar_type(),
self.sizes());
if (!result.is_same(self)) {
result.copy_(self);
}
EXEC_NPU_CMD(aclnnIndexAdd, result, dim, index, source, alpha, result);
return result;
}
at::Tensor index_add(
const at::Tensor& self,
int64_t dim,
const at::Tensor& index,
const at::Tensor& source,
const at::Scalar& alpha) {
DO_COMPATIBILITY(aclnnIndexAdd, acl_op::index_add(self, dim, index, source, alpha));
auto self_sizes = self.sizes().vec();
auto source_sizes = source.sizes().vec();
if (source.dim() != 0 && self.dim() != 0) {
auto wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
self_sizes.erase(self_sizes.begin() + wrapped_dim);
source_sizes.erase(source_sizes.begin() + wrapped_dim);
}
TORCH_CHECK(self_sizes == source_sizes,
"source tensor shape must match self tensor shape, excluding the specified dimension. Got self.shape = ",
self.sizes(),
" source.shape = ",
source.sizes(),
OPS_ERROR(ErrCode::PARAM));
at::Tensor result = at_npu::native::OpPreparation::apply_tensor_without_format(self.sizes(), self.options());
EXEC_NPU_CMD(aclnnIndexAdd, result.copy_(self), dim, index, source, alpha, result.copy_(self));
return result;
}
at::Tensor index_add(
const at::Tensor& self,
at::Dimname dim,
const at::Tensor& index,
const at::Tensor& source,
const at::Scalar& alpha)
{
return op_api::index_add(self, dimname_to_position(self, dim), index, source, alpha);
}
}