#include <ATen/NamedTensorUtils.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
namespace {
// Runs the "ArgMaxWithValue" operator on `self` along `dim`.
//
// `output` receives the values and `indices` receives the positions. No shape, dtype or
// layout validation is done here; callers prepare both tensors beforehand.
// Returns references to (output, indices), i.e. values first.
std::tuple<at::Tensor &, at::Tensor &> max_out_npu_nocheck(at::Tensor &output, at::Tensor &indices,
                                                           const at::Tensor &self, int64_t dim, bool keepdim)
{
    at_npu::native::OpCommand arg_max_cmd;
    arg_max_cmd.Name("ArgMaxWithValue")
        .Input(self)
        // Outputs are registered as indices first, values second.
        .Output(indices)
        .Output(output)
        .Attr("dimension", dim)
        .Attr("keep_dims", keepdim)
        .Run();
    return std::tuple<at::Tensor &, at::Tensor &>(output, indices);
}
// Runs the "Maximum" operator with two tensor inputs and writes into `result`.
// No validation or dtype promotion happens here. Returns `result`.
at::Tensor &max_out_npu_nocheck(at::Tensor &result, const at::Tensor &self, const at::Tensor &other)
{
    at_npu::native::OpCommand maximum_cmd;
    maximum_cmd.Name("Maximum")
        .Input(self)
        .Input(other)
        .Output(result)
        .Run();
    return result;
}
// Runs the "ReduceMax" operator on `self` over the axes in `dims` and writes into `result`.
// `dims` is passed to the operator as its second input; `keepdim` is forwarded as the
// "keep_dims" attribute. No validation happens here. Returns `result`.
at::Tensor &max_out_npu_nocheck(at::Tensor &result, const at::Tensor &self, at::IntArrayRef dims, bool keepdim)
{
    at_npu::native::OpCommand reduce_max_cmd;
    reduce_max_cmd.Name("ReduceMax")
        .Input(self)
        .Input(dims)
        .Output(result)
        .Attr("keep_dims", keepdim)
        .Run();
    return result;
}
// Runs the "Maximum" operator with a tensor and a scalar operand and writes into `result`.
// The scalar is handed to the operator together with `self`'s scalar type.
// No validation happens here. Returns `result`.
at::Tensor &max_out_npu_nocheck(at::Tensor &result, const at::Tensor &self, const at::Scalar &other)
{
    at_npu::native::OpCommand maximum_cmd;
    maximum_cmd.Name("Maximum")
        .Input(self)
        .Input(other, self.scalar_type())
        .Output(result)
        .Run();
    return result;
}
}
// Out-variant of max along a single dimension.
//
// `max` receives the maximum values (checked against self's dtype) and `max_values`
// receives the indices (checked against Long). Both are checked against the reduced shape
// with ND format. The kernel itself is run with an Int index tensor whose contents are
// cast to Long and copied into `max_values` afterwards.
// Returns references to (max, max_values).
std::tuple<at::Tensor &, at::Tensor &> max_out(const at::Tensor &self, int64_t dim, bool keepdim, at::Tensor &max,
                                               at::Tensor &max_values)
{
    at::SmallVector<int64_t, SIZE> reduce_dims = {dim};
    auto reduced_shape = op_infer::reduce_ops_npu_output_size(self, reduce_dims, keepdim);
    npu_preparation::CheckOut({self}, max, ACL_FORMAT_ND, self.scalar_type(), reduced_shape);
    npu_preparation::CheckOut({self}, max_values, ACL_FORMAT_ND, at::ScalarType::Long, reduced_shape);

    // Int-typed working tensor for the indices output of the kernel.
    at::Tensor int_indices = at_npu::native::custom_ops::_npu_dtype_cast(max_values, at::ScalarType::Int);

    // NOTE(review): check_match presumably reports whether the kernel can write into the
    // tensor directly - confirm against NpuUtils.
    const bool values_direct = npu_utils::check_match(&max);
    const bool indices_direct = npu_utils::check_match(&int_indices);

    if (values_direct && indices_direct) {
        max_out_npu_nocheck(max, int_indices, self, dim, keepdim);
    } else {
        // At least one output needs a staging tensor; results are written back afterwards.
        at::Tensor values_buffer = values_direct ? max : npu_utils::format_contiguous(max);
        at::Tensor indices_buffer = indices_direct ? int_indices : npu_utils::format_contiguous(int_indices);
        max_out_npu_nocheck(values_buffer, indices_buffer, self, dim, keepdim);
        if (!values_direct) {
            npu_utils::format_fresh_view(max, values_buffer);
        }
        if (!indices_direct) {
            npu_utils::format_fresh_view(int_indices, indices_buffer);
        }
    }

    at::Tensor long_indices = at_npu::native::custom_ops::_npu_dtype_cast(int_indices, at::ScalarType::Long);
    max_values.copy_(long_indices);
    return std::tie(max, max_values);
}
// Named-dimension out-variant: resolves `dim` to its position in `self` and forwards to the
// integer-dimension acl_op::max_out overload.
std::tuple<at::Tensor &, at::Tensor &> max_out(const at::Tensor &self, at::Dimname dim, bool keepdim,
                                               at::Tensor &max, at::Tensor &max_values)
{
    const int64_t dim_position = dimname_to_position(self, dim);
    return acl_op::max_out(self, dim_position, keepdim, max, max_values);
}
// Maximum values and their indices along `dim`.
//
// Bool and Int inputs are cast to Float before the kernel runs, and the resulting values are
// cast back to the input's scalar type afterwards. Indices are produced as Int by the kernel
// and returned as Long.
// Returns (values, indices).
std::tuple<at::Tensor, at::Tensor> max(const at::Tensor &self, int64_t dim, bool keepdim)
{
    const bool via_float = self.dtype() == at::ScalarType::Bool || self.dtype() == at::ScalarType::Int;
    at::Tensor kernel_input =
        via_float ? at_npu::native::custom_ops::_npu_dtype_cast(self, at::ScalarType::Float) : self;

    at::SmallVector<int64_t, SIZE> reduce_dims = {dim};
    auto reduced_shape = op_infer::reduce_ops_npu_output_size(kernel_input, reduce_dims, keepdim);

    at::Tensor values =
        npu_preparation::apply_tensor_with_format(reduced_shape, kernel_input.options(), ACL_FORMAT_ND);
    at::Tensor positions = npu_preparation::apply_tensor_with_format(
        reduced_shape, kernel_input.options().dtype(at::ScalarType::Int), ACL_FORMAT_ND);

    max_out_npu_nocheck(values, positions, kernel_input, dim, keepdim);

    positions = at_npu::native::custom_ops::_npu_dtype_cast(positions, at::ScalarType::Long);
    if (via_float) {
        // Restore the caller's dtype for the values output.
        values = at_npu::native::custom_ops::_npu_dtype_cast(values, self.scalar_type());
    }
    return std::tuple<at::Tensor, at::Tensor>(values, positions);
}
// Named-dimension variant: resolves `dim` to its position in `self` and calls at::max with
// the integer dimension.
std::tuple<at::Tensor, at::Tensor> max(const at::Tensor &self, at::Dimname dim, bool keepdim)
{
    const int64_t dim_position = dimname_to_position(self, dim);
    return at::max(self, dim_position, keepdim);
}
// Element-wise maximum of `self` and `other`, written into `out`.
//
// Operands whose dtype differs from at::native::result_type(self, other) are cast to it,
// except operands flagged as scalars wrapped into tensors, which are passed through as-is.
// `out` is checked against the (possibly cast) first operand and the broadcast shape.
// Returns `out`.
at::Tensor &max_out(const at::Tensor &self, const at::Tensor &other, at::Tensor &out)
{
    auto broadcast_shape = op_infer::broadcast_ops_npu_output_size(self, other);
    const at::ScalarType promoted_type = at::native::result_type(self, other);

    // Casts an operand to the promoted dtype unless it already has it or is a wrapped scalar.
    const auto promote = [promoted_type](const at::Tensor &operand) -> at::Tensor {
        const bool needs_cast =
            operand.scalar_type() != promoted_type && !npu_preparation::is_scalar_wrapped_to_tensor(operand);
        return needs_cast ? at_npu::native::custom_ops::_npu_dtype_cast(operand, promoted_type) : operand;
    };
    at::Tensor lhs = promote(self);
    at::Tensor rhs = promote(other);

    npu_preparation::CheckOut({lhs, rhs}, out, lhs, broadcast_shape);

    if (npu_utils::check_match(&out)) {
        max_out_npu_nocheck(out, lhs, rhs);
        return out;
    }

    // `out` cannot be used directly: compute into a staging tensor, then write back.
    at::Tensor staging = npu_utils::format_contiguous(out);
    max_out_npu_nocheck(staging, lhs, rhs);
    npu_utils::format_fresh_view(out, staging);
    return out;
}
// Element-wise maximum of `self` and `other`, written into `out`.
//
// Forwards to acl_op::max_out(self, other, out) so that this entry point applies the same
// dtype promotion (at::native::result_type), output check and write-back handling as the
// max.out tensor/tensor overload and as maximum(). Without the promotion, operands of
// different dtypes would be handed to the "Maximum" kernel uncast and `out` would be
// validated against `self`'s dtype instead of the promoted one.
// Returns `out`.
at::Tensor &maximum_out(const at::Tensor &self, const at::Tensor &other, at::Tensor &out)
{
    return acl_op::max_out(self, other, out);
}
// Element-wise maximum of `self` and `other`, returned as a new tensor.
//
// Two paths:
//  * `other` is a CPU scalar: the result is allocated with `self`'s sizes and the scalar
//    value (other.item()) is passed to the Scalar overload of the kernel helper. No dtype
//    promotion is performed on this path.
//  * otherwise: operands whose dtype differs from at::native::result_type(self, other) are
//    cast to it (operands flagged as wrapped scalars are left as-is), and the result is
//    allocated from the first operand with the broadcast shape.
at::Tensor maximum(const at::Tensor &self, const at::Tensor &other)
{
    if (npu_preparation::IsCPUScalar(other)) {
        // Allocate only on this path; the tensor path below allocates its own result.
        at::Tensor scalar_result = npu_preparation::apply_tensor(self, self.sizes());
        max_out_npu_nocheck(scalar_result, self, other.item());
        return scalar_result;
    }

    auto output_size = op_infer::broadcast_ops_npu_output_size(self, other);
    at::ScalarType high_type = at::native::result_type(self, other);

    at::Tensor self_copy = self;
    if (self.scalar_type() != high_type && !npu_preparation::is_scalar_wrapped_to_tensor(self)) {
        self_copy = at_npu::native::custom_ops::_npu_dtype_cast(self, high_type);
    }
    at::Tensor other_copy = other;
    if (other.scalar_type() != high_type && !npu_preparation::is_scalar_wrapped_to_tensor(other)) {
        other_copy = at_npu::native::custom_ops::_npu_dtype_cast(other, high_type);
    }

    at::Tensor result = npu_preparation::apply_tensor(self_copy, output_size);
    max_out_npu_nocheck(result, self_copy, other_copy);
    return result;
}
// Maximum of `self` over the axes in `dim`, returned as a new tensor.
//
// The result is allocated from `self` with the reduced shape. It uses `self`'s NPU format
// unless the reduced shape is empty, in which case ND format is used.
at::Tensor amax(const at::Tensor &self, at::IntArrayRef dim, bool keepdim)
{
    auto reduced_shape = op_infer::reduce_ops_npu_output_size(self, dim, keepdim);
    const int64_t self_format = npu_preparation::get_tensor_npu_format(self);
    const int64_t result_format = reduced_shape.empty() ? static_cast<int64_t>(ACL_FORMAT_ND) : self_format;

    at::Tensor reduced = npu_preparation::apply_tensor_with_format(self, reduced_shape, result_format);
    max_out_npu_nocheck(reduced, self, dim, keepdim);
    return reduced;
}
// Maximum over the dimension list that get_dimlist_for_tensor returns for `self`,
// computed via acl_op::amax with keepdim disabled.
at::Tensor max(const at::Tensor &self)
{
    at::SmallVector<int64_t, SIZE> all_dims = op_plugin::utils::get_dimlist_for_tensor(self);
    const bool keepdim = false;
    return acl_op::amax(self, all_dims, keepdim);
}
// Out-variant of amax: maximum of `self` over the axes in `dim`, written into `out`.
//
// `out` is checked against ND format, `self`'s scalar type and the reduced shape.
// Returns `out`.
at::Tensor &amax_out(const at::Tensor &self, at::IntArrayRef dim, bool keepdim, at::Tensor &out)
{
    auto reduced_shape = op_infer::reduce_ops_npu_output_size(self, dim, keepdim);
    npu_preparation::CheckOut({self}, out, ACL_FORMAT_ND, self.scalar_type(), reduced_shape);

    if (npu_utils::check_match(&out)) {
        max_out_npu_nocheck(out, self, dim, keepdim);
        return out;
    }

    // `out` cannot be used directly: compute into a staging tensor, then write back.
    at::Tensor staging = npu_utils::format_contiguous(out);
    max_out_npu_nocheck(staging, self, dim, keepdim);
    npu_utils::format_fresh_view(out, staging);
    return out;
}
}