#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
// Runs the "RepeatInterleave" op with a scalar repeat count.
//
// result  - tensor handed to the op as its output; also the return value.
// self    - first op input.
// repeats - repeat count, forwarded to the op as an at::kLong scalar input.
// dim     - forwarded to the op as its "axis" attribute.
//
// No argument validation happens here; callers are expected to have checked
// the arguments and allocated `result` beforehand.
at::Tensor& repeat_interleave_out_nocheck(at::Tensor& result, const at::Tensor& self, int64_t repeats, int64_t dim)
{
    // Keep the scalar as a named local so it stays alive until Run() returns.
    at::Scalar repeat_count(repeats);

    at_npu::native::OpCommand op_cmd;
    op_cmd.Name("RepeatInterleave");
    op_cmd.Input(self);
    op_cmd.Input(repeat_count, at::kLong);
    op_cmd.Output(result);
    op_cmd.Attr("axis", dim);
    op_cmd.Run();

    return result;
}
// Runs the "RepeatInterleave" op with a tensor of repeat counts.
//
// result  - tensor handed to the op as its output; also the return value.
// self    - first op input.
// repeats - second op input, passed through unchanged.
// dim     - forwarded to the op as its "axis" attribute.
//
// No argument validation happens here; callers are expected to have checked
// the arguments and allocated `result` beforehand.
at::Tensor& repeat_interleave_out_nocheck(at::Tensor& result, const at::Tensor& self, const at::Tensor& repeats,
                                          int64_t dim)
{
    at_npu::native::OpCommand op_cmd;
    op_cmd.Name("RepeatInterleave");
    op_cmd.Input(self);
    op_cmd.Input(repeats);
    op_cmd.Output(result);
    op_cmd.Attr("axis", dim);
    op_cmd.Run();

    return result;
}
// Verifies that `real_dim` is a legal dimension index for a tensor that has
// `self_dim` dimensions, raising through TORCH_CHECK otherwise.
//
// The accepted interval is bounded by -self_dim and self_dim - 1. Whichever of
// the two is smaller is used as the lower bound, so for self_dim == 0 the
// interval becomes [-1, 0] rather than an empty range.
void check_dim_valid(int64_t real_dim, int64_t self_dim)
{
    const int64_t negated_rank = -self_dim;
    const int64_t last_index = self_dim - 1;

    // Order the two candidates explicitly instead of assuming negated_rank <= last_index.
    const bool negated_is_lower = negated_rank < last_index;
    const int64_t lowest = negated_is_lower ? negated_rank : last_index;
    const int64_t highest = negated_is_lower ? last_index : negated_rank;

    const bool out_of_range = (real_dim < lowest) || (real_dim > highest);
    TORCH_CHECK(
        !out_of_range,
        "dim value should be in the range of [-x, x-1], x is the dimension number of input tensor.",
        OPS_ERROR(ErrCode::VALUE));
}
}
// repeat_interleave with a single integer repeat count.
//
// self    - input tensor.
// repeats - number of repetitions; must be >= 1 (0 and negative values are
//           rejected by the TORCH_CHECK below).
// dim     - optional axis. When absent, the input is flattened and axis 0 is used.
//
// Returns a freshly allocated ND-format tensor filled by the RepeatInterleave
// op, except when repeats == 1: then the (possibly flattened) input handle is
// returned as-is and no op is launched, so the result refers to the same
// tensor as `self` when `dim` is given.
at::Tensor repeat_interleave_common_nocheck(
    const at::Tensor& self,
    int64_t repeats,
    c10::optional<int64_t> dim)
{
    int64_t real_dim = dim.value_or(0);
    int64_t self_dim = self.dim();
    check_dim_valid(real_dim, self_dim);

    // The guard rejects 0 as well as negative values, so the message has to say
    // so; the previous text ("can not be negative") was misleading for repeats == 0.
    TORCH_CHECK(
        repeats >= 1,
        "repeats must be greater than or equal to 1.",
        OPS_ERROR(ErrCode::VALUE));

    at::Tensor self_tensor = self;
    if (!dim.has_value()) {
        // No axis requested: operate on the flattened input along axis 0.
        self_tensor = at::flatten(self_tensor);
    }

    // Repeating once is the identity; skip the allocation and the op launch.
    if (repeats == 1) {
        return self_tensor;
    }

    auto op_infer_output_size = op_infer::repeat_interleave_npu_output_size(self_tensor, repeats, real_dim);
    at::Tensor result = npu_preparation::apply_tensor_with_format(self_tensor, op_infer_output_size, ACL_FORMAT_ND);
    repeat_interleave_out_nocheck(result, self_tensor, repeats, real_dim);
    return result;
}
// repeat_interleave with a tensor of repeat counts.
//
// self    - input tensor.
// repeats - repeat counts. A 0-d tensor is treated as a 1-element 1-d tensor.
//           Its length along dim 0 must be 1 or equal to the size of the
//           (possibly flattened) input along the chosen axis.
// dim     - optional axis. When absent, the input is flattened and axis 0 is used.
//
// Returns a freshly allocated ND-format tensor filled by the RepeatInterleave op.
// The caller's `repeats` tensor is left untouched.
at::Tensor repeat_interleave_common_nocheck(
    const at::Tensor& self,
    const at::Tensor& repeats,
    c10::optional<int64_t> dim)
{
    int64_t real_dim = dim.value_or(0);
    int64_t self_dim = self.dim();
    check_dim_valid(real_dim, self_dim);

    at::Tensor self_tensor = self;
    if (!dim.has_value()) {
        // No axis requested: operate on the flattened input along axis 0.
        self_tensor = at::flatten(self_tensor);
    }

    // Promote a 0-d repeats tensor to 1-d with the out-of-place unsqueeze.
    // Copying the at::Tensor handle and calling unsqueeze_() on the copy would
    // reshape the tensor the caller passed in, because both handles share the
    // same underlying tensor.
    at::Tensor repeats_1d = repeats;
    if (repeats.dim() == 0) {
        repeats_1d = repeats.unsqueeze(0);
    }

    TORCH_CHECK(
        (repeats_1d.size(0) == self_tensor.size(real_dim)) || (repeats_1d.size(0) == 1),
        "repeats must have the same size as input along dim.", OPS_ERROR(ErrCode::VALUE));

    // Output-size inference is fed a Float copy of the counts; the op itself
    // receives the counts in their original dtype.
    at::Tensor repeats_float = acl_op::_npu_dtype_cast(repeats_1d, at::ScalarType::Float);
    auto op_infer_output_size = op_infer::repeat_interleave_npu_output_size(self_tensor, repeats_float, real_dim);

    at::Tensor result = npu_preparation::apply_tensor_with_format(self_tensor, op_infer_output_size, ACL_FORMAT_ND);
    repeat_interleave_out_nocheck(result, self_tensor, repeats_1d, real_dim);
    return result;
}
}