#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
namespace {
void index_fill_d_check_index(at::IntArrayRef shape, const at::Tensor &index, int64_t dim)
{
TORCH_CHECK(index.dim() == 1, "Index should be a one-dimensional tensor"
+ OPS_ERROR(ErrCode::PARAM));
int index_temp = INT_MAX;
for (int i = 0; i < index.sizes()[0]; i++) {
index_temp = static_cast<int>(op_plugin::utils::get_scalar_float_value(index[i].item()));
TORCH_CHECK(shape[dim] > index_temp, "Index out of range, it should be in [0,", shape[dim], ")"
+ OPS_ERROR(ErrCode::VALUE));
}
}
c10::SmallVector<float, N> index_fill_d_assist_help_init(int64_t dim, at::IntArrayRef sizes, std::vector<int> index,
bool flag, float value)
{
int blocksize = 0;
int blocknum = 1;
int n = 1;
for (uint i = 0; i < sizes.size(); i++) {
if (i <= dim) {
blocknum *= sizes[i];
}
n *= sizes[i];
}
blocksize = n / blocknum;
c10::SmallVector<float, N> ast;
ast.resize(n);
if (flag) {
ast = c10::SmallVector<float, N>(n, 1);
} else {
ast = c10::SmallVector<float, N>(n, 0);
}
for (uint i = 0; i < index.size(); i++) {
int start = 0;
int end = 0;
int idx = index[i];
int k = idx;
int count = 0;
while (k < blocknum) {
start = blocksize * k;
end = start + blocksize;
for (int j = start; j < end; j++) {
ast[j] = value;
}
count++;
k = idx + sizes[dim] * count;
}
}
return ast;
}
at::Tensor index_fill_d_assist_help(const at::Tensor &self, const at::Tensor &index, int64_t dim, at::Scalar value,
bool flag)
{
c10::SmallVector<float, N> assist;
at::IntArrayRef size = self.sizes();
std::vector<int> index_vector;
for (int i = 0; i < index.sizes()[0]; i++) {
int index_temp = static_cast<int>(op_plugin::utils::get_scalar_float_value(index[i].item()));
index_vector.push_back(index_temp);
}
float value_float = op_plugin::utils::get_scalar_float_value(value);
assist = index_fill_d_assist_help_init(dim, size, index_vector, flag, value_float);
at::Tensor assist_help = at::from_blob(assist.data(), size, dtype(at::ScalarType::Float));
return assist_help.to(at::device(torch_npu::utils::get_npu_device_type()));
}
at::Tensor &index_fill_d_nocheck(at::Tensor &result, const at::Tensor &self, int64_t dim, const at::Tensor &index,
at::Scalar value)
{
if (self.numel() == 0) {
return result;
}
at::Scalar value_zeros = at::Scalar(0.0);
const at::Tensor *aclInput = &self;
at::Tensor assist_help1 = index_fill_d_assist_help(self, index, dim, value_zeros, true);
at::Tensor assist_help2 = index_fill_d_assist_help(self, index, dim, value, false);
at::ScalarType self_type = self.scalar_type();
assist_help1 = assist_help1.scalar_type() == self_type ?
assist_help1 :
at_npu::native::custom_ops::_npu_dtype_cast(assist_help1, self_type);
assist_help2 = assist_help2.scalar_type() == self_type ?
assist_help2 :
at_npu::native::custom_ops::_npu_dtype_cast(assist_help2, self_type);
at_npu::native::OpCommand cmd;
cmd.Name("IndexFillD").Input(self).Input(assist_help1).Input(assist_help2).Attr("dim", dim).Output(result).Run();
return result;
}
}
at::Tensor index_fill(const at::Tensor &self, int64_t dim, const at::Tensor &index, const at::Scalar &value)
{
at::Tensor result = npu_preparation::apply_tensor(self);
index_fill_d_nocheck(result, self, dim, index, value);
return result;
}
at::Tensor &index_fill_(at::Tensor &self, int64_t dim, const at::Tensor &index, const at::Scalar &value)
{
at::IntArrayRef shape_self = self.sizes();
index_fill_d_check_index(shape_self, index, dim);
if (!npu_utils::check_match(&self)) {
at::Tensor contiguous_self = npu_utils::format_contiguous(self);
index_fill_d_nocheck(contiguous_self, contiguous_self, dim, index, value);
npu_utils::format_fresh_view(self, contiguous_self);
} else {
index_fill_d_nocheck(self, self, dim, index, value);
}
return self;
}
at::Tensor index_fill(const at::Tensor &self, int64_t dim, const at::Tensor &index, const at::Tensor &value)
{
at::IntArrayRef shape_self = self.sizes();
index_fill_d_check_index(shape_self, index, dim);
TORCH_CHECK(value.dim() == 0, "Value should be a 0-dimensional tensor,but got ", value.dim(),
OPS_ERROR(ErrCode::PARAM));
at::Scalar value_scalar = value.item();
at::Tensor result = npu_preparation::apply_tensor(self);
index_fill_d_nocheck(result, self, dim, index, value_scalar);
return result;
}
at::Tensor &index_fill_(at::Tensor &self, int64_t dim, const at::Tensor &index, const at::Tensor &value)
{
at::IntArrayRef shape_self = self.sizes();
index_fill_d_check_index(shape_self, index, dim);
TORCH_CHECK(value.dim() == 0, "Value should be a 0-dimensional tensor,but got ", value.dim(),
OPS_ERROR(ErrCode::PARAM));
at::Scalar value_scalar = value.item();
if (!npu_utils::check_match(&self)) {
at::Tensor contiguous_self = npu_utils::format_contiguous(self);
index_fill_d_nocheck(contiguous_self, contiguous_self, dim, index, value_scalar);
npu_utils::format_fresh_view(self, contiguous_self);
} else {
index_fill_d_nocheck(self, self, dim, index, value_scalar);
}
return self;
}
}