#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
using namespace std;
/**
 * Runs the "UnsortedSegmentSum" operator over `x`, grouped by `segment_ids`.
 *
 * The output is allocated with x's options and the shape
 *   [num_segments] + x.sizes()[segment_ids.dim():]
 * i.e. the leading dimensions of `x` covered by `segment_ids` are collapsed
 * into a single dimension of length `num_segments`.
 *
 * @param x            Input tensor passed as the operator's first input.
 * @param segment_ids  Segment index tensor; must not have more dimensions than `x`.
 * @param num_segments Tensor whose single value (read via item()) gives the
 *                     size of the output's first dimension.
 * @return The freshly allocated output tensor written by the operator.
 */
at::Tensor scatter_max_backward(const at::Tensor& x, const at::Tensor& segment_ids, const at::Tensor& num_segments)
{
    const auto x_sizes = x.sizes();
    const auto segment_ids_dims = segment_ids.dim();
    // Guard the iterator arithmetic below: skipping more dimensions than x
    // actually has would move the begin iterator past end().
    TORCH_CHECK(segment_ids_dims <= x.dim(),
        "segment_ids's dimension should not be greater than x's dimension.");

    c10::SmallVector<int64_t, SIZE> output_size;
    output_size.push_back(num_segments.item().toLong());
    // Append the trailing dimensions of x that segment_ids does not cover.
    std::copy(x_sizes.begin() + segment_ids_dims, x_sizes.end(), std::back_inserter(output_size));

    at::Tensor out = at::empty(output_size, x.options());
    at_npu::native::OpCommand cmd;
    cmd.Name("UnsortedSegmentSum")
        .Input(x)
        .Input(segment_ids)
        .Input(num_segments)
        .Output(out)
        .Attr("check_ids", true)
        .Run();
    return out;
}
/**
 * Validates the shapes of the scatter_max operands; raises via TORCH_CHECK on
 * the first violated rule.
 *
 * Rules, checked in this order:
 *   1. `src` and `index` each have at least one dimension.
 *   2. `res` has the same number of dimensions as `src`.
 *   3. `src` and `res` agree on every dimension except dim 0.
 *   4. Every dimension of `index` after the first has size 1.
 *   5. `index` and `src` have the same size along dim 0.
 *
 * @param src   Source tensor.
 * @param index Index tensor.
 * @param res   Output tensor that will receive the result.
 */
void scatter_max_validate(const at::Tensor& src, const at::Tensor& index, const at::Tensor& res)
{
    const auto indexSizes = index.sizes();
    const auto srcSizes = src.sizes();
    const auto resSizes = res.sizes();
    const size_t src_dims = srcSizes.size();
    const size_t index_dims = indexSizes.size();
    const size_t res_dims = resSizes.size();

    TORCH_CHECK(src_dims != 0 && index_dims != 0, "src and index should not be empty.");
    TORCH_CHECK(res_dims == src_dims, "out's dimension should be equal to src's dimension.");
    for (size_t i = 1; i < res_dims; i++) {
        TORCH_CHECK(srcSizes[i] == resSizes[i], "src and out should have the same size except for dim 0.");
    }
    // Check each trailing dimension directly instead of multiplying the sizes
    // together: sizes are non-negative, so "product == 1" is the same as
    // "every factor == 1", and this form cannot overflow or narrow.
    for (size_t i = 1; i < index_dims; i++) {
        TORCH_CHECK(indexSizes[i] == 1,
            "all the dims's range except the first dim of input tensor [index] should be equal to 1.");
    }
    TORCH_CHECK(
        indexSizes[0] == srcSizes[0], "input's src size of dim 0 should be equal to index's size.");
}
/**
 * Scatter-max entry point: runs the aclnnScatterMaxV1 and
 * aclnnScatterMaxArgmaxV1 kernels on (src, index) and returns (res, argmax).
 *
 * @param src   Source tensor; must have at least one dimension.
 * @param index Index tensor; its maximum value must be non-negative.
 * @param out   Optional result tensor. When absent, a float tensor shaped like
 *              `src` but with dim 0 set to max(index) + 1 is allocated and
 *              pre-filled with -inf. When present it is used (and modified)
 *              in place.
 * @return (res, argmax). Entries of `res` still equal to -inf after the first
 *         kernel are replaced by 0; entries of `argmax` still equal to -1
 *         after the second kernel are replaced by src.size(0).
 */
std::tuple<at::Tensor, at::Tensor> scatter_max(
    const at::Tensor& src, const at::Tensor& index, c10::optional<at::Tensor> out)
{
    // Same precondition scatter_max_validate enforces, hoisted so that
    // sizes[0] below is never written on an empty shape vector.
    TORCH_CHECK(src.dim() != 0 && index.dim() != 0, "src and index should not be empty.");

    auto sizes = src.sizes().vec();
    const auto idxMaxVal = index.max().item().toLong();
    TORCH_CHECK(idxMaxVal >= 0, "invalid index value.");
    sizes[0] = idxMaxVal + 1;

    const float ninf = -std::numeric_limits<float>::infinity();
    // Only allocate and fill a fresh result when the caller did not supply
    // one; value_or() would build the fallback tensor unconditionally.
    at::Tensor res = out.has_value()
        ? out.value()
        : at::empty(sizes, src.options().dtype(at::kFloat)).fill_(ninf);
    at::Tensor argmax = at::empty(res.sizes(), res.options().dtype(at::kInt)).fill_(-1);

    scatter_max_validate(src, index, res);

    // First kernel; presumably writes the per-slot maxima into res
    // (NOTE(review): confirm against the aclnnScatterMaxV1 contract).
    EXEC_NPU_CMD(aclnnScatterMaxV1, src, index, res, argmax);
    // Entries still at the -inf sentinel are reset to 0.
    res.masked_fill_(res == ninf, 0.0f);

    // Second kernel; presumably fills argmax from the finished res
    // (NOTE(review): confirm against the aclnnScatterMaxArgmaxV1 contract).
    EXEC_NPU_CMD(aclnnScatterMaxArgmaxV1, src, index, res, argmax);
    // Entries still at the -1 sentinel are replaced by src's dim-0 size.
    const auto argmaxInvalidVal = src.size(0);
    argmax.masked_fill_(argmax == -1, argmaxInvalidVal);

    return std::make_tuple(res, argmax);
}