#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
at::Tensor conv_tbc(const at::Tensor &self, const at::Tensor &weight, const at::Tensor &bias, int64_t pad)
{
TORCH_CHECK(self.dim() == 3, "Input must have 3 dims: time, batch, in_channel." + OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(weight.dim() == 3, "Weight tensor must have 3 dims: kernel_width,"
" in_channels, out_channels." + OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(bias.dim() == 1, "Bias must be 1-D." + OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(self.size(2) == weight.size(1), "Input dim 2 (input channels) "
"is not == dim 1 in the weight tenso." + OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(weight.size(2) == bias.size(0), "Bias size must equal dim 2 in "
"the weight tensor (output channels)." + OPS_ERROR(ErrCode::PARAM));
int64_t C = weight.size(2);
int64_t W = (self.size(0) + 2 * pad - (weight.size(0) - 1) - 1) + 1;
TORCH_CHECK(W > 0, "W has to be positive, but got ", W, OPS_ERROR(ErrCode::VALUE));
c10::SmallVector<int64_t, SIZE> output_size = {self.size(1), C, 1, W};
at::Tensor result = npu_preparation::apply_tensor_with_format(self, output_size, ACL_FORMAT_NCHW);
c10::SmallVector<int64_t, N> paddings = {0, 0, pad, pad};
c10::SmallVector<int64_t, N> strides_size = {1, 1, 1, 1};
c10::SmallVector<int64_t, N> dilations = {1, 1, 1, 1};
at::Tensor self_tensor = self.transpose(0, 2).transpose(0, 1).unsqueeze(2);
at::Tensor weight_tensor = weight.transpose(0, 2).unsqueeze(2);
at_npu::native::OpCommand cmd;
cmd.Name("Conv2D")
.Input(self_tensor, "x")
.Input(weight_tensor, "filter")
.Input(bias)
.Output(result, "y")
.Attr("pads", paddings)
.Attr("strides", strides_size)
.Attr("dilations", dilations)
.Attr("data_format", (string) "NCHW")
.Run();
result = result.squeeze(2).transpose(0, 2).transpose(1, 2);
return result;
}
}