#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;

// Launches the "FillDiagonal" NPU operator.
//
// @param result  Tensor bound as the operator output (callers in this file
//                pass the same tensor as `self`).
// @param self    Tensor bound as the operator input.
// @param value   Fill value; converted with get_scalar_float_value and passed
//                as the "fill_value" attribute.
// @param wrap    Forwarded unchanged as the "wrap" attribute.
// @return        `result`.
at::Tensor& fill_diagonal_out_npu(
    at::Tensor& result,
    const at::Tensor& self,
    const at::Scalar& value,
    bool wrap)
{
    // NOTE(review): the attribute is a float, so double or large integer
    // scalars may be rounded before reaching the kernel.
    float diag_value = op_plugin::utils::get_scalar_float_value(value);

    at_npu::native::OpCommand fill_cmd;
    fill_cmd.Name("FillDiagonal");
    fill_cmd.Input(self);
    fill_cmd.Output(result);
    fill_cmd.Attr("fill_value", diag_value);
    fill_cmd.Attr("wrap", wrap);
    fill_cmd.Run();

    return result;
}

// In-place diagonal fill of `self` with `fill_value` on the NPU.
//
// `self` is first cast back to its original format. When
// npu_utils::check_match(&self) holds, the operator writes into `self`
// directly. Otherwise it runs on a format-contiguous copy, and
// npu_utils::format_fresh_view is used to bring the result back into `self`.
// (Presumably check_match tests whether `self`'s layout can be written
// directly; confirm against NpuUtils.)
//
// @param self        Tensor to modify in place.
// @param fill_value  Value written on the diagonal.
// @param wrap        Forwarded to the "wrap" operator attribute.
// @return            `self`.
at::Tensor& fill_diagonal_(at::Tensor& self, const at::Scalar& fill_value, bool wrap)
{
    npu_preparation::CastBackToOriFormat(self);

    if (npu_utils::check_match(&self)) {
        fill_diagonal_out_npu(self, self, fill_value, wrap);
        return self;
    }

    // Layout mismatch: fill a contiguous staging tensor, then refresh `self` from it.
    at::Tensor staged = npu_utils::format_contiguous(self);
    fill_diagonal_out_npu(staged, staged, fill_value, wrap);
    npu_utils::format_fresh_view(self, staged);
    return self;
}
} // namespace acl_op