#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Short alias for the NPU OpPreparation type. It is not referenced in the
// visible part of this file.
using npu_preparation = at_npu::native::OpPreparation;
/**
 * @brief Out-of-place scatter-list: runs aclnnScatterList on clones of `self`.
 *
 * Every tensor in `self` is cloned, the clones are handed to the
 * aclnnScatterList kernel together with the remaining arguments, and the
 * clones are returned. The tensors in `self` themselves are never passed to
 * the kernel.
 *
 * @param self    Tensors to clone; one output tensor is produced per entry.
 * @param indice  Forwarded unchanged to aclnnScatterList.
 * @param updates Forwarded unchanged to aclnnScatterList.
 * @param mask    Optional tensor forwarded unchanged to aclnnScatterList.
 * @param reduce  String forwarded to the kernel as a C string (presumably the
 *                reduction mode; confirm against the aclnnScatterList docs).
 * @param axis    Forwarded unchanged to aclnnScatterList.
 * @return The cloned tensors after the kernel call, in the same order as `self`.
 */
std::vector<at::Tensor> npu_scatter_list(
    at::TensorList self,
    const at::Tensor &indice,
    const at::Tensor &updates,
    const c10::optional<at::Tensor> &mask,
    c10::string_view reduce,
    int64_t axis)
{
    // Copy the view into an owning std::string so that c_str() yields a
    // null-terminated buffer. reduce_ptr points into reduce_str, so reduce_str
    // must stay in scope for as long as the pointer is used.
    std::string reduce_str = std::string(reduce);
    // NOTE(review): the macro is handed a mutable char*; presumably its
    // parameter conversion expects that exact type. Confirm before changing.
    char *reduce_ptr = const_cast<char *>(reduce_str.c_str());

    // The final size is known up front, so reserve once instead of letting the
    // vector reallocate while the clones are appended.
    std::vector<at::Tensor> result;
    result.reserve(self.size());
    for (const at::Tensor &tensor : self) {
        result.push_back(tensor.clone());
    }

    // Non-owning view over the clones; this is what the kernel receives.
    at::TensorList result_ = at::TensorList(result);
    EXEC_NPU_CMD(aclnnScatterList, result_, indice, updates, mask, reduce_ptr, axis);
    return result;
}
/**
 * @brief In-place variant: runs aclnnScatterList directly on `self`.
 *
 * Unlike the out-of-place wrapper, no clones are made; the tensors in `self`
 * are passed straight to the kernel and nothing is returned.
 *
 * @param self    Tensor list handed to aclnnScatterList as its first argument.
 * @param indice  Forwarded unchanged to aclnnScatterList.
 * @param updates Forwarded unchanged to aclnnScatterList.
 * @param mask    Optional tensor forwarded unchanged to aclnnScatterList.
 * @param reduce  String forwarded to the kernel as a C string (presumably the
 *                reduction mode; confirm against the aclnnScatterList docs).
 * @param axis    Forwarded unchanged to aclnnScatterList.
 */
void npu_scatter_list_(
    at::TensorList self,
    const at::Tensor &indice,
    const at::Tensor &updates,
    const c10::optional<at::Tensor> &mask,
    c10::string_view reduce,
    int64_t axis)
{
    // Owning, null-terminated copy of the view. reduce_cstr points into
    // reduce_mode, so reduce_mode must stay in scope while the pointer is used.
    std::string reduce_mode = std::string(reduce);
    // NOTE(review): the macro is handed a mutable char*; presumably its
    // parameter conversion expects that exact type. Confirm before changing.
    char *reduce_cstr = const_cast<char *>(reduce_mode.c_str());

    EXEC_NPU_CMD(aclnnScatterList, self, indice, updates, mask, reduce_cstr, axis);
}
}