#include <ATen/native/SparseTensorUtils.h>
#include <torch/library.h>
#include <torch/torch.h>
#include "torch_npu/csrc/core/npu/DeviceUtils.h"
#include "op_plugin/SparseOpsInterface.h"
#include "SparseTensorUtils.h"
namespace sparse {
/**
 * @brief Writes `dense + value * sparse` into `r_` for a dense `self` and a sparse `other`.
 *
 * The computation is carried out in the promoted dtype of `dense` and `sparse`; when `r_` has a
 * different dtype a temporary buffer of the promoted dtype is used and copied into `r_` at the end.
 *
 * @param r_     Output tensor (must be an NPU tensor). May be the same tensor as `dense`, in which
 *               case it must be contiguous.
 * @param dense  Dense operand (must be an NPU tensor).
 * @param sparse Sparse operand (must be an NPU tensor) with the same sizes as `dense`;
 *               broadcasting is not supported.
 * @param value  Scalar multiplier applied to the sparse values before accumulation.
 * @return Reference to `r_`.
 *
 * Raises (via TORCH_CHECK) when any tensor is not on NPU, when the sizes differ, when the promoted
 * dtype cannot be cast to the dtype of `r_`, or when `r_` aliases `dense` but is not contiguous.
 */
at::Tensor& add_out_dense_sparse_npu(
    at::Tensor& r_,
    const at::Tensor& dense,
    const at::sparse::SparseTensor& sparse,
    const at::Scalar& value)
{
    TORCH_CHECK(torch_npu::utils::is_npu(dense), "sparse add: expected 'self' to be a NPU tensor",
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(torch_npu::utils::is_npu(sparse), "sparse add: expected 'other' to be a NPU tensor",
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(torch_npu::utils::is_npu(r_), "sparse add: expected 'out' to be a NPU tensor",
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(dense.sizes().equals(sparse.sizes()),
                "add: expected 'self' and 'other' to have same size, but self has size ",
                dense.sizes(), " while other has size ", sparse.sizes(),
                " (FYI: dense-sparse addition does not currently support broadcasting)", OPS_ERROR(ErrCode::VALUE));

    // Validate type promotion before any fast path, so an incompatible `out` dtype is rejected
    // regardless of how many non-zero entries `sparse` happens to hold.
    const auto common_dtype = at::result_type(dense, sparse);
    TORCH_CHECK(c10::canCast(common_dtype, r_.scalar_type()), "Can't convert result type ", common_dtype,
                " to output ", r_.scalar_type(), OPS_ERROR(ErrCode::TYPE));

    const int64_t nnz = sparse._nnz();
    if (nnz == 0) {
        // Nothing to accumulate: the result is just `dense`.
        r_.resize_as_(dense);
        if (!at::sparse::is_same_tensor(r_, dense)) {
            r_.copy_(dense);
        }
        return r_;
    }

    // `r` is the accumulation buffer: `r_` itself when dtypes agree, otherwise a temporary
    // of the promoted dtype.
    at::Tensor r = r_;
    if (r_.scalar_type() != common_dtype) {
        r = at::empty_like(dense, r_.options().dtype(common_dtype));
    }

    at::Tensor dense_buffer = dense.to(common_dtype);
    at::Tensor values = sparse._values().to(common_dtype);

    if (at::sparse::is_same_tensor(r, dense_buffer)) {
        // In-place case: `r` already holds the dense data; it is reshaped with view() below,
        // so a contiguous layout is required.
        TORCH_CHECK(r_.is_contiguous(),
                    "add: NPU dense-sparse addition with a non-contiguous output tensor does not work",
                    OPS_ERROR(ErrCode::VALUE));
    } else {
        r.resize_as_(dense);
        r.copy_(dense_buffer);
    }

    if (values.numel() == 0) {
        // nnz > 0 but the value tensor holds no elements: nothing to add.
        // NOTE(review): on the promoted-dtype path `r_` is returned without being written here;
        // this mirrors the pre-existing behavior — confirm it is intended.
        return r_;
    }

    at::Tensor indices = sparse._indices();
    const int64_t n_dim = dense.dim();
    const int64_t n_dimI = sparse.sparse_dim();

    // Collapse the sparse dimensions into one row index and the remaining (dense) dimensions
    // into one column index, so the update becomes a single row-wise index_add_.
    at::Tensor indices_1D = at::sparse::flatten_indices(indices, sparse.sizes(), false);
    int64_t view_rows = 1;
    int64_t view_columns = 1;
    for (int64_t i = 0; i < n_dimI; ++i) {
        view_rows *= r.size(i);
    }
    for (int64_t i = n_dimI; i < n_dim; ++i) {
        view_columns *= r.size(i);
    }

    at::Tensor r_view = r.view({view_rows, view_columns});
    values = values.reshape({nnz, view_columns});
    r_view.index_add_(0, indices_1D, values, value);

    // Only the promoted-dtype path used a separate buffer that must be copied (and cast) back.
    if (!at::sparse::is_same_tensor(r, r_)) {
        r_.copy_(r);
    }
    return r_;
}
/**
 * @brief Writes `t + value * src` into the sparse output `r_`.
 *
 * When `t` is not sparse the call is forwarded to add_out_dense_sparse_npu (dense + sparse).
 * Otherwise both operands must be sparse NPU tensors of identical sizes and density. The result is
 * built by concatenating the indices and (scaled) values of both operands, so it is in general not
 * coalesced; it is coalesced only on the dtype-promotion path or when nnz exceeds numel.
 *
 * @param t     Left operand ('self').
 * @param src   Right operand ('other'); must be sparse.
 * @param value Scalar multiplier applied to the values of `src`.
 * @param r_    Output tensor ('out'); must be on NPU and able to hold the promoted dtype.
 * @return Reference to `r_`.
 *
 * Raises (via TORCH_CHECK) when `src` is dense, when a tensor is not on NPU, when the promoted
 * dtype cannot be cast to the dtype of `r_`, or when sizes / densities of the operands differ.
 */
at::sparse::SparseTensor& add_out_sparse(
    const at::sparse::SparseTensor& t,
    const at::sparse::SparseTensor& src,
    const at::Scalar& value,
    at::sparse::SparseTensor& r_)
{
    if (!t.is_sparse()) {
        return add_out_dense_sparse_npu(r_, t, src, value);
    }

    TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead.",
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(torch_npu::utils::is_npu(t), "add: expected 'self' to be NPU, but got ", t.device(),
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(torch_npu::utils::is_npu(src), "add: expected 'other' to be NPU, but got ", src.device(),
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(torch_npu::utils::is_npu(r_), "add: expected 'out' to be NPU, but got ", r_.device(),
                OPS_ERROR(ErrCode::VALUE));

    auto common_dtype = at::result_type(t, src);
    TORCH_CHECK(c10::canCast(common_dtype, r_.scalar_type()), "Can't convert result type ",
                common_dtype, " to output ", r_.scalar_type(), OPS_ERROR(ErrCode::TYPE));
    TORCH_CHECK(t.sizes().equals(src.sizes()), "add: expected 'self' and 'other' to have same size, but ",
                t.sizes(), " != ", src.sizes(), OPS_ERROR(ErrCode::VALUE));

    // Fast paths when one operand has no stored entries.
    if (src._nnz() == 0) {
        return at::copy_sparse_to_sparse_(r_, t);
    }
    if (t._nnz() == 0) {
        // mul_out_sparse_scalar is defined outside this block; presumably it stores src * value in r_.
        return mul_out_sparse_scalar(r_, src, value);
    }

    TORCH_CHECK(at::sparse::is_same_density(t, src),
                "add: expected 'self' and 'other' to have same density, but 'self' has ",
                t.sparse_dim(), " sparse dimensions while 'other' has ", src.sparse_dim(), " sparse dimensions",
                OPS_ERROR(ErrCode::VALUE));

    // Concatenate indices and values of both operands instead of merging them; entries that
    // share an index stay separate until the tensor is coalesced.
    at::Tensor t_indices_ = t._indices();
    at::Tensor s_indices_ = src._indices();
    at::Tensor t_values_ = t._values().to(common_dtype);
    at::Tensor s_values_ = src._values().to(common_dtype);
    s_values_ = s_values_.mul(value);
    at::Tensor r_indices_ = at::cat({t_indices_, s_indices_}, 1);
    at::Tensor r_values_ = at::cat({t_values_, s_values_}, 0);

    if (r_.scalar_type() != common_dtype) {
        // Coalesce in the promoted dtype first, so duplicates are summed before the values
        // are cast to the dtype of `r_`.
        at::sparse::SparseTensor promoted = at::empty({0}, r_.options().dtype(common_dtype));
        promoted.resize_as_(src);
        at::sparse::alias_into_sparse(promoted, r_indices_, r_values_);
        promoted = promoted.coalesce();
        r_values_ = promoted._values().to(r_.scalar_type());
        r_indices_ = promoted._indices();
    }

    // Give the output the operands' sizes on BOTH paths before installing indices/values;
    // previously the promoted path aliased into `r_` without resizing it.
    r_.resize_as_(src);
    alias_into_sparse(r_, r_indices_, r_values_);

    // Keep nnz bounded: once more entries are stored than the tensor has elements, coalesce.
    if (r_._nnz() > r_.numel()) {
        auto c = r_.coalesce();
        alias_into_sparse(r_, c._indices(), c._values());
    }
    return r_;
}
}