#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TypeProperties.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Short aliases for the NPU helper classes referenced by the kernels in this file.
using npu_preparation = at_npu::native::OpPreparation;
// NOTE(review): npu_utils is not referenced anywhere in the visible part of this file.
using npu_utils = at_npu::native::NpuUtils;
// Builds the list of tensors that actually take part in the concatenation.
//
// A tensor is dropped when it is 1-D and its single dimension has extent 0;
// every other tensor is kept, in its original order.
//
// @param tensors  materialized input list
// @return         the retained tensors
static c10::SmallVector<at::Tensor, op_infer::N> cat_dest_tensor_list_opapi(
    const at::MaterializedITensorListRef& tensors)
{
    c10::SmallVector<at::Tensor, op_infer::N> dst_tensor_list;
    // Bind by const reference: iterating by value would make an extra copy of
    // every tensor handle before the one that emplace_back already makes.
    for (const at::Tensor& tensor : tensors) {
        const bool is_empty_1d = tensor.dim() == 1 && tensor.sizes()[0] == 0;
        if (!is_empty_1d) {
            dst_tensor_list.emplace_back(tensor);
        }
    }
    return dst_tensor_list;
}
// Computes the shape produced by concatenating `tensors` along `dimension`.
//
// Tensors that are 1-D and occupy zero bytes are ignored. If every tensor is
// ignored the result is the shape {0}. Otherwise the result copies the shape
// of the first tensor that is not ignored, with the entry at index `dimension`
// replaced by the sum of size(dimension) over all tensors that are not ignored.
//
// NOTE(review): `dimension` is matched against indices 0..rank-1, so callers
// presumably pass an already-normalized, non-negative value — confirm.
//
// @param tensors    candidate inputs (not modified)
// @param dimension  axis along which the inputs are concatenated
// @return           the output shape
static c10::SmallVector<int64_t, op_infer::SIZE> cat_npu_output_size_opapi(
    c10::SmallVector<at::Tensor, op_infer::N>& tensors, int64_t dimension)
{
    auto should_skip = [](const at::Tensor& t) { return t.nbytes() == 0 && t.dim() == 1; };

    // First tensor that is not skipped; it supplies the rank and the extents
    // of every axis other than `dimension`. Stays null when all are skipped.
    const at::Tensor* reference = nullptr;
    int64_t cat_dim_size = 0;
    for (const at::Tensor& tensor : tensors) {
        if (should_skip(tensor)) {
            continue;
        }
        if (reference == nullptr) {
            reference = &tensor;
        }
        cat_dim_size += tensor.size(dimension);
    }

    if (reference == nullptr) {
        c10::SmallVector<int64_t, op_infer::SIZE> size = {0};
        return size;
    }

    const int64_t rank = reference->dim();
    c10::SmallVector<int64_t, op_infer::SIZE> size;
    size.resize(rank);
    for (int64_t d = 0; d < rank; d++) {
        size[d] = (d == dimension) ? cat_dim_size : reference->size(d);
    }
    return size;
}
// Rejects any input whose dim() is 0: walks the list in order and raises via
// TORCH_CHECK on the first such tensor, reporting its position in the list.
inline void cat_check_no_zero_dim(const at::MaterializedITensorListRef& tensors)
{
    const size_t count = tensors.size();
    for (size_t position = 0; position < count; ++position) {
        const at::Tensor& candidate = tensors[position];
        TORCH_CHECK(
            candidate.dim() > 0,
            "zero-dimensional tensor (at position ", position, ") cannot be concatenated" + OPS_ERROR(ErrCode::PARAM));
    }
}
// Concatenates `tensors` along `dim` into the caller-provided `out` using aclnnCat.
//
// Steps, in order:
//   1. DO_COMPATIBILITY names acl_op::cat_out as the alternative implementation
//      (presumably taken when aclnnCat is unavailable — confirm against the macro).
//   2. The list must be non-empty and must not contain 0-dim tensors.
//   3. Empty 1-D tensors are filtered out. If nothing remains, `out` is checked
//      against shape {0} and returned without launching the kernel.
//   4. Otherwise `dim` is passed through make_warp_dim with the rank of the first
//      retained tensor, the output shape is computed, `out` is checked against
//      it, aclnnCat is executed and names are propagated to `out`.
//
// @param tensors  inputs to concatenate
// @param dim      concatenation axis
// @param out      destination tensor
// @return         `out`
at::Tensor& cat_out(const at::ITensorListRef& tensors, int64_t dim, at::Tensor& out)
{
    DO_COMPATIBILITY(aclnnCat, acl_op::cat_out(tensors, dim, out));
    auto materialized = tensors.materialize();
    // materialized[0] is dereferenced below; reject an empty list up front
    // instead of indexing out of bounds.
    TORCH_CHECK(
        !materialized.empty(),
        "torch.cat(): expected a non-empty list of Tensors" + OPS_ERROR(ErrCode::PARAM));
    cat_check_no_zero_dim(materialized);

    c10::SmallVector<at::Tensor, op_infer::N> inputTensors = cat_dest_tensor_list_opapi(materialized);
    if (inputTensors.empty()) {
        // Every input was an empty 1-D tensor: the result has shape {0}.
        npu_preparation::check_tensor({materialized[0].get()}, out, at::IntArrayRef({0}));
        return out;
    }

    at::TensorList tensor_list(inputTensors.begin(), inputTensors.end());
    const int64_t dim_post_expr = inputTensors[0].dim();
    dim = op_plugin::utils::make_warp_dim(dim, dim_post_expr);
    auto maybe_outnames = at::namedinference::compute_cat_outnames(materialized);
    auto outputSize = cat_npu_output_size_opapi(inputTensors, dim);
    npu_preparation::check_tensor({materialized[0].get()}, out, at::IntArrayRef(outputSize));
    EXEC_NPU_CMD(aclnnCat, tensor_list, dim, out);
    at::namedinference::propagate_names_if_nonempty(out, maybe_outnames);
    return out;
}
// Concatenates `tensors` along `dim` into a newly allocated tensor using aclnnCat.
//
// Steps, in order:
//   1. DO_COMPATIBILITY names acl_op::cat as the alternative implementation
//      (presumably taken when aclnnCat is unavailable — confirm against the macro).
//   2. The list must be non-empty and must not contain 0-dim tensors.
//   3. Empty 1-D tensors are filtered out. If nothing remains, a tensor allocated
//      from the first input is returned without launching the kernel.
//   4. Otherwise the result is allocated with the computed output shape and the
//      dtype given by at::native::result_type over all inputs, aclnnCat is
//      executed and names are propagated to the result.
//
// @param tensors  inputs to concatenate
// @param dim      concatenation axis
// @return         the concatenated tensor
at::Tensor cat(const at::ITensorListRef& tensors, int64_t dim)
{
    DO_COMPATIBILITY(aclnnCat, acl_op::cat(tensors, dim));
    auto materialized = tensors.materialize();
    // materialized[0] is dereferenced below; reject an empty list up front
    // instead of indexing out of bounds.
    TORCH_CHECK(
        !materialized.empty(),
        "torch.cat(): expected a non-empty list of Tensors" + OPS_ERROR(ErrCode::PARAM));
    cat_check_no_zero_dim(materialized);

    c10::SmallVector<at::Tensor, op_infer::N> inputTensors = cat_dest_tensor_list_opapi(materialized);
    at::ScalarType high_type = at::native::result_type(materialized);
    if (inputTensors.empty()) {
        // Every input was an empty 1-D tensor.
        // NOTE(review): this result is allocated from materialized[0] alone, so
        // high_type is not applied on this path — confirm that is intended.
        at::Tensor result = npu_preparation::apply_tensor_without_format(materialized[0]);
        return result;
    }

    at::TensorList tensor_list(inputTensors.begin(), inputTensors.end());
    const int64_t dim_post_expr = inputTensors[0].dim();
    dim = op_plugin::utils::make_warp_dim(dim, dim_post_expr);
    auto maybe_outnames = at::namedinference::compute_cat_outnames(materialized);
    auto outputSize = cat_npu_output_size_opapi(inputTensors, dim);
    at::Tensor result =
        npu_preparation::apply_tensor_without_format(outputSize, inputTensors[0].options().dtype(high_type));
    EXEC_NPU_CMD(aclnnCat, tensor_list, dim, result);
    at::namedinference::propagate_names_if_nonempty(result, maybe_outnames);
    return result;
}
// Named-dimension variant of cat_out: resolves `dim` to a positional index
// using the first tensor, then forwards to at::cat_out.
//
// @param tensors  inputs to concatenate; must be non-empty
// @param dim      name of the concatenation axis, looked up on tensors[0]
// @param out      destination tensor
// @return         the reference returned by at::cat_out
at::Tensor& cat_out(at::TensorList tensors, at::Dimname dim, at::Tensor& out)
{
    DO_COMPATIBILITY(aclnnCat, acl_op::cat_out(tensors, dim, out));
    // tensors[0] is read below; an empty list would index out of bounds.
    TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors" + OPS_ERROR(ErrCode::PARAM));
    return at::cat_out(out, tensors, dimname_to_position(tensors[0], dim));
}
// Named-dimension variant of cat: resolves `dim` to a positional index using
// the first tensor, then forwards to at::cat.
//
// @param tensors  inputs to concatenate; must be non-empty
// @param dim      name of the concatenation axis, looked up on tensors[0]
// @return         the concatenated tensor
at::Tensor cat(at::TensorList tensors, at::Dimname dim)
{
    DO_COMPATIBILITY(aclnnCat, acl_op::cat(tensors, dim));
    // tensors[0] is read below; an empty list would index out of bounds.
    TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors" + OPS_ERROR(ErrCode::PARAM));
    return at::cat(tensors, dimname_to_position(tensors[0], dim));
}
}