#include "SparseTensorUtils.h"
#include <ATen/native/DispatchStub.h>
namespace sparse {
// Collapses per-dimension sparse indices into a single linear offset per entry.
//
// indices: tensor whose i-th slice (indices[i]) holds the coordinates along
//          dimension i; the result has indices.size(1) elements.
//          NOTE(review): presumably shaped (sparse_dim, nnz) - confirm with the caller.
// size:    extent of each indexed dimension; one entry per slice of `indices` that is used.
//
// Returns a tensor created with indices.options() that holds
// sum_i indices[i] * stride[i], where stride[i] is the product of size[i+1..].
// An empty `size` yields an all-zero result.
at::Tensor flatten_indices_npu_kernel(const at::Tensor& indices, c10::IntArrayRef size)
{
    const size_t ndim = size.size();
    std::vector<int64_t> flatten_size(ndim, 1);
    // Count down from ndim rather than ndim - 1: with an unsigned counter,
    // `ndim - 1` wraps around when `size` is empty and the loop would then
    // read and write far outside both containers.
    for (size_t i = ndim; i > 1; i--) {
        flatten_size[i - 2] = flatten_size[i - 1] * size[i - 1];
    }
    auto tensor_temp = torch::zeros({indices.size(1)}, indices.options());
    // Accumulate each dimension's coordinate scaled by its stride.
    for (size_t i = 0; i < ndim; i++) {
        tensor_temp += indices[i] * flatten_size[i];
    }
    return tensor_temp;
}
// Multiplies the sparse tensor `t` by the zero-dim tensor `value`, writing the
// result into the sparse tensor `r` and returning `r`.
//
// `value` may itself be sparse: if it has no stored elements the result is
// `r` resized like `t` and zeroed; otherwise its values tensor is used as the
// multiplier. The multiplier must contain exactly one element.
// When `r` and `t` are the same tensor the values are scaled in place.
SparseTensor& mul_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const at::Tensor& value)
{
    AT_ASSERT(r.is_sparse());
    AT_ASSERT(t.is_sparse());
    AT_ASSERT(value.dim() == 0);

    // A sparse multiplier with nothing stored acts as zero.
    if (value.is_sparse() && value._nnz() == 0) {
        r.resize_as_(t);
        return r.zero_();
    }

    // Pick the tensor that actually carries the single multiplier element.
    const at::Tensor scalar = value.is_sparse() ? value.values() : value;
    AT_ASSERT(scalar.numel() == 1);

    // In-place case: only the stored values need scaling.
    if (is_same_tensor(r, t)) {
        r._values().mul_(scalar);
        return r;
    }

    // Out-of-place case: mirror t's layout into r, then fill in the scaled values.
    r.resize_as_(t);
    auto out_indices = r._indices();
    out_indices.resize_as_(t._indices());
    out_indices.copy_(t._indices());

    // mul_out needs a mutable lvalue for its output argument.
    at::Tensor out_values = r._values();
    at::mul_out(out_values, t._values(), scalar);

    get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
    r._coalesced_(t.is_coalesced());
    return r;
}
// Scalar overload: wraps `value` into a tensor and forwards to
// mul_out_sparse_zerodim, returning its result (`r`).
SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, const at::Scalar& value)
{
    const at::Tensor wrapped_value = at::native::wrapped_scalar_tensor(value);
    return mul_out_sparse_zerodim(r, t, wrapped_value);
}
}
// Dispatch wiring: binds the NPU flatten-indices kernel defined in ::sparse
// to flatten_indices_stub.
namespace at {
namespace native {
// Function-pointer signature that implementations bound to flatten_indices_stub must match.
using flatten_indices_fn = at::Tensor (*)(const at::Tensor& indices, at::IntArrayRef size);
// Declares the stub for this translation unit.
// NOTE(review): the stub's definition is presumably provided elsewhere (ATen) - confirm.
DECLARE_DISPATCH(flatten_indices_fn, flatten_indices_stub);
// Registers ::sparse::flatten_indices_npu_kernel for the stub via the PrivateUse1 registration macro.
REGISTER_PRIVATEUSE1_DISPATCH(flatten_indices_stub, &::sparse::flatten_indices_npu_kernel);
} // namespace native
} // namespace at