#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "op_plugin/utils/custom_functions/aclops/inner_compute.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
// Runs the "ArgMaxWithValue" operator on `self` and writes its two results
// into tensors supplied by the caller.
//
//   output  - bound to the operator's second output (going by the operator
//             name, presumably the maximum values along `dim`).
//   indices - bound to the operator's first output (presumably the positions
//             of those maxima).
//   self    - the operator's only input.
//   dim     - forwarded as the "dimension" attribute.
//   keepdim - forwarded as the "keep_dims" attribute.
//
// This function performs no validation, allocation or resizing of `output`
// and `indices`; it only wires them into the operator call.
//
// Returns references to (output, indices), in that order.
std::tuple<at::Tensor&, at::Tensor&> max_v1_out_nocheck(
    at::Tensor& output,
    at::Tensor& indices,
    const at::Tensor& self,
    int64_t dim,
    bool keepdim)
{
    at_npu::native::OpCommand launcher;
    // Outputs are registered as (indices, output); note that this is the
    // reverse of the (output, indices) order used for the return value.
    launcher.Name("ArgMaxWithValue")
        .Input(self)
        .Output(indices)
        .Output(output)
        .Attr("dimension", dim)
        .Attr("keep_dims", keepdim)
        .Run();
    return std::tuple<at::Tensor&, at::Tensor&>(output, indices);
}
} // namespace
// Allocates the two result tensors for a reduction of `self` over `dim` and
// fills them through max_v1_out_nocheck.
//
//   self    - tensor to reduce.
//   dim     - dimension to reduce over; also forwarded to the kernel.
//   keepdim - forwarded to both the shape inference and the kernel.
//
// Returns (values, indices). `values` is created with self's tensor options;
// `indices` uses the same options with dtype overridden to at::kInt.
std::tuple<at::Tensor, at::Tensor> npu_max(const at::Tensor& self, int64_t dim, bool keepdim)
{
    c10::SmallVector<int64_t, SIZE> dims = {dim};
    // Both results share one shape, so the shape inference is run once and
    // reused (the same call with the same arguments used to be made twice).
    c10::SmallVector<int64_t, SIZE> reduced_size =
        op_infer::reduce_ops_npu_output_size(self, dims, keepdim);

    // An empty inferred shape gets ACL_FORMAT_NCHW; otherwise the values
    // tensor takes the format reported for `self`.
    const int64_t values_format =
        reduced_size.empty() ? ACL_FORMAT_NCHW : npu_preparation::get_tensor_npu_format(self);

    at::Tensor values = npu_preparation::apply_tensor_with_format(
        reduced_size, self.options(), values_format);
    // The index tensor is always allocated as int32 in ACL_FORMAT_NCHW.
    at::Tensor indices = npu_preparation::apply_tensor_with_format(
        reduced_size, self.options().dtype(at::kInt), ACL_FORMAT_NCHW);

    max_v1_out_nocheck(values, indices, self, dim, keepdim);
    return std::make_tuple(values, indices);
}
// Named-dimension overload: resolves `dim` to a position within `self` via
// dimname_to_position, then forwards to the acl_op::npu_max overload that
// accepts the resulting position. `keepdim` is passed through unchanged.
std::tuple<at::Tensor, at::Tensor> npu_max(const at::Tensor& self, at::Dimname dim, bool keepdim)
{
    const auto dim_position = dimname_to_position(self, dim);
    return acl_op::npu_max(self, dim_position, keepdim);
}
// Backward helper for npu_max (going by its name). It hands a zero tensor of
// shape `sizes_symint`, the indices and the incoming gradient to
// acl_op::npu_scatter along `dim`, and returns what that call produces.
//
//   grad         - incoming gradient.
//   dim          - dimension passed to squeeze (when applicable) and to
//                  npu_scatter.
//   indices      - index tensor; cast to at::kInt if its dtype is at::kLong.
//   sizes_symint - shape of the zero tensor that is scattered into.
//   keepdim      - when true and `sizes_symint` is non-empty, `dim` is
//                  squeezed out of both `grad` and `indices` first.
at::Tensor npu_max_backward_symint(const at::Tensor &grad, int64_t dim, const at::Tensor &indices,
                                   c10::SymIntArrayRef sizes_symint, bool keepdim)
{
    // Unchecked view of the symbolic sizes as plain integers (per the helper's
    // name, no validation is done here).
    at::IntArrayRef input_shape = c10::asIntArrayRefUnchecked(sizes_symint);

    // Squeezing only applies when the reduced dim was kept and there is at
    // least one dimension in the target shape.
    const bool drop_reduced_dim = keepdim && !input_shape.empty();
    at::Tensor grad_values = drop_reduced_dim ? grad.squeeze(dim) : grad;
    at::Tensor index_values = drop_reduced_dim ? indices.squeeze(dim) : indices;

    // int64 indices are narrowed to int32 before being handed to npu_scatter.
    if (index_values.dtype() == at::kLong) {
        index_values = at_npu::native::custom_ops::_npu_dtype_cast(index_values, at::kInt);
    }

    return acl_op::npu_scatter(at::zeros(input_shape, grad_values.options()), index_values, grad_values, dim);
}
}