#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
// Short alias for the helper used below to allocate the result tensor.
using npu_preparation = at_npu::native::OpPreparation;
// Minimum number of elements required in the stride / padding / dilation arrays.
constexpr int ATTRS_DIM = 2;
// Minimum number of dimensions required for the input and weight tensors.
constexpr int TENSORS_DIM = 4;
// Dimension index of the input tensor that is read as the height (H).
constexpr int INPUT_H_INDEX = 2;
// Dimension index of the input tensor that is read as the width (W).
constexpr int INPUT_W_INDEX = 3;
// Dimension index of the weight tensor that is checked to be non-zero.
constexpr int WEIGHT_W_INDEX = 3;
// Launches the "QuantConv2D" operator and writes its result into `output`.
//
// Parameters:
//   input, weight, scale - bound to the operator inputs "x", "filter" and "scale".
//   stride, pad, dilation - must each hold at least ATTRS_DIM (2) elements; only the
//                           first two elements are used.
//   groups, offset_x     - forwarded unchanged as operator attributes.
//   round_mode           - currently NOT forwarded; see the NOTE(review) below.
//   output               - pre-allocated tensor bound to the operator output "y".
//   input_dtype, weight_dtype, offset - accepted for interface compatibility but not
//                           read by this function.
//   output_dtype         - must be set and equal to at::ScalarType::Half.
//   bias_opt             - when it holds a defined tensor, it is appended as an extra
//                           operator input after "scale".
//
// Returns `output`.
// Fails via TORCH_CHECK when an attribute array is too short or output_dtype is
// missing / not float16.
at::Tensor npu_quant_conv2d_out(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& scale,
                                c10::IntArrayRef stride, c10::IntArrayRef pad,
                                c10::IntArrayRef dilation, int64_t groups,
                                int64_t offset_x, c10::string_view round_mode, at::Tensor& output,
                                c10::optional<int64_t> input_dtype, c10::optional<int64_t> weight_dtype,
                                c10::optional<int64_t> output_dtype, const c10::optional<at::Tensor>& bias_opt,
                                const c10::optional<at::Tensor>& offset)
{
    TORCH_CHECK(stride.size() >= ATTRS_DIM, "stride has to contain at least 2 elements, but got ", stride.size());
    TORCH_CHECK(pad.size() >= ATTRS_DIM, "padding has to contain at least 2 elements, but got ", pad.size());
    TORCH_CHECK(dilation.size() >= ATTRS_DIM, "dilation has to contain at least 2 elements, but got ",
                dilation.size());
    TORCH_CHECK(output_dtype.has_value(), "output_dtype can not be None");
    TORCH_CHECK(output_dtype.value() == static_cast<int64_t>(at::ScalarType::Half),
                "only support float16 as output_dtype");

    // An absent bias is represented by an undefined tensor.
    const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });

    // Expand the 2-element attributes to the 4-element form passed to the operator:
    // strides / dilations get a leading {1, 1}; each padding value is duplicated.
    c10::SmallVector<int64_t, N> strides = {1, 1, stride[0], stride[1]};
    c10::SmallVector<int64_t, N> paddings = {pad[0], pad[0], pad[1], pad[1]};
    c10::SmallVector<int64_t, N> dilations = {1, 1, dilation[0], dilation[1]};

    // The checks above guarantee a float16 output here, for which the "dtype"
    // attribute value is 1.
    int64_t dtype_enum = 1;

    at_npu::native::OpCommand cmd;
    cmd.Name("QuantConv2D").Input(input, "x").Input(weight, "filter").Input(scale, "scale");
    if (bias.defined()) {
        cmd.Input(bias);
    }
    // NOTE(review): the `round_mode` argument is ignored and "rint" is always sent.
    // Confirm whether other modes are supported by the operator before forwarding it.
    cmd.Output(output, "y")
        .Attr("dtype", dtype_enum)
        .Attr("strides", strides)
        .Attr("pads", paddings)
        .Attr("dilations", dilations)
        .Attr("groups", groups)
        .Attr("data_format", static_cast<std::string>("NCHW"))
        .Attr("offset_x", offset_x)
        .Attr("round_mode", static_cast<std::string>("rint"))
        .Run();
    return output;
}
// Validates the arguments, allocates a float16 NCHW result tensor of shape
// {batch, out_channels, Ho, Wo} and delegates to acl_op::npu_quant_conv2d_out.
//
// Parameters:
//   input   - at least 4-D; dims 0, INPUT_H_INDEX and INPUT_W_INDEX are read as the
//             batch size, height and width.
//   weight  - at least 4-D; dim 0 is read as the output channel count and dims 2 / 3
//             as the kernel height / width.
//   strides, pads, dilations - at least 2 elements each (H then W); strides must be
//             non-zero.
//   scale, groups, offset_x, round_mode, output_dtype, bias, offset, input_dtype,
//   weight_dtype - forwarded unchanged to npu_quant_conv2d_out.
//
// Returns the newly allocated result tensor.
// Fails via TORCH_CHECK on invalid ranks / attribute sizes, a zero stride, a dilated
// kernel larger than the padded input, or a non-positive output size.
at::Tensor npu_quant_conv2d(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& scale,
                            c10::IntArrayRef strides, c10::IntArrayRef pads, c10::IntArrayRef dilations,
                            int64_t groups, int64_t offset_x, c10::string_view round_mode,
                            c10::optional<int64_t> output_dtype, const c10::optional<at::Tensor>& bias,
                            const c10::optional<at::Tensor>& offset, c10::optional<int64_t> input_dtype,
                            c10::optional<int64_t> weight_dtype)
{
    TORCH_CHECK(input.dim() >= TENSORS_DIM, "input has to be at least 4D, but got Tensor of dimension ",
                input.dim());
    TORCH_CHECK(weight.dim() >= TENSORS_DIM, "weight has to be at least 4D, but got Tensor of dimension ",
                weight.dim());
    TORCH_CHECK(strides.size() >= ATTRS_DIM, "stride has to contain at least 2 elements, but got ",
                strides.size());
    TORCH_CHECK(pads.size() >= ATTRS_DIM, "padding has to contain at least 2 elements, but got ", pads.size());
    TORCH_CHECK(dilations.size() >= ATTRS_DIM, "dilation has to contain at least 2 elements, but got ",
                dilations.size());
    TORCH_CHECK(weight.size(WEIGHT_W_INDEX) != 0, "4th dim of weight cannot be 0");
    // Check each stride on its own: it is used as a divisor below, and a product
    // of the two could overflow.
    TORCH_CHECK(strides[0] != 0 && strides[1] != 0, "Stride cannot contain 0");

    const int64_t batch = input.size(0);
    const int64_t in_h = input.size(INPUT_H_INDEX);
    const int64_t in_w = input.size(INPUT_W_INDEX);
    const int64_t out_channels = weight.size(0);
    auto kernel_size = weight.sizes().slice(2);

    // Padded input extent minus the dilated kernel extent, per spatial dimension.
    const int64_t h_extent = in_h + 2 * pads[0] - dilations[0] * (kernel_size[0] - 1) - 1;
    const int64_t w_extent = in_w + 2 * pads[1] - dilations[1] * (kernel_size[1] - 1) - 1;
    // Integer division truncates toward zero, so a small negative extent divided by a
    // larger stride would yield 0 and make the output size look like a valid 1.
    // Reject a kernel that does not fit explicitly before dividing.
    TORCH_CHECK(h_extent >= 0, "dilated kernel height cannot be greater than padded input height, ",
                "but the difference is ", h_extent);
    TORCH_CHECK(w_extent >= 0, "dilated kernel width cannot be greater than padded input width, ",
                "but the difference is ", w_extent);

    int64_t Ho = h_extent / strides[0] + 1;
    int64_t Wo = w_extent / strides[1] + 1;
    TORCH_CHECK(Ho > 0, "Ho has to be positive, but got ", Ho);
    TORCH_CHECK(Wo > 0, "Wo has to be positive, but got ", Wo);

    c10::SmallVector<int64_t, SIZE> output_size = {batch, out_channels, Ho, Wo};
    // The result is always allocated as float16, whatever the input dtype is.
    c10::TensorOptions options = input.options().dtype(at::kHalf);
    at::Tensor result = npu_preparation::apply_tensor_with_format(output_size, options, ACL_FORMAT_NCHW);
    acl_op::npu_quant_conv2d_out(input, weight, scale, strides, pads, dilations, groups, offset_x, round_mode,
                                 result, input_dtype, weight_dtype, output_dtype, bias, offset);
    return result;
}
#endif
}