#include <ATen/native/TypeProperties.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "op_plugin/utils/custom_functions/aclops/inner_compute.h"
namespace acl_op {
// Shorthand aliases for helper classes in at_npu::native.
using npu_preparation = at_npu::native::OpPreparation;
using calcu_op_util = at_npu::native::CalcuOpUtil;
using npu_utils = at_npu::native::NpuUtils;
namespace {
// Builds the list of tensors that take part in the concatenation.
//
// - 1-D tensors whose only dimension has size 0 are dropped.
// - Every remaining tensor whose dtype differs from the promoted type of the
//   whole input list (at::native::result_type) is cast to that type.
//
// @param tensors  materialized input list
// @return the kept tensors, in input order, all sharing the promoted dtype
c10::SmallVector<at::Tensor, N> cat_dest_tensor_list(const at::MaterializedITensorListRef& tensors)
{
    const auto promoted_type = at::native::result_type(tensors);
    c10::SmallVector<at::Tensor, N> prepared;
    for (const at::Tensor& source : tensors) {
        const bool is_flat_empty = (source.dim() == 1) && (source.sizes()[0] == 0);
        if (is_flat_empty) {
            continue;
        }
        if (source.scalar_type() == promoted_type) {
            prepared.emplace_back(source);
        } else {
            prepared.emplace_back(at_npu::native::custom_ops::_npu_dtype_cast(source, promoted_type));
        }
    }
    return prepared;
}
// Runs the concatenation into an already prepared `result` without validating it.
//
// - A single-element input list is handled by copying that tensor into `result`.
// - Otherwise the inputs are filtered/cast by cat_dest_tensor_list; if nothing
//   is left, `result` is returned untouched.
// - Tensors with zero elements are not passed to the ConcatD operator. When no
//   tensor with elements remains, the operator is not launched at all: the
//   concatenation of zero-element tensors has zero elements, so there is
//   nothing to compute.
//
// @param result   destination tensor (written in place)
// @param tensors  materialized input list
// @param dim      concatenation dimension; wrapped against the rank of the
//                 first kept input
// @return `result`
at::Tensor& cat_output_nocheck(at::Tensor& result, const at::MaterializedITensorListRef& tensors, int64_t dim)
{
    if (tensors.size() == 1) {
        return result.copy_(tensors[0].get());
    }

    c10::SmallVector<at::Tensor, N> input_tensors = cat_dest_tensor_list(tensors);
    if (input_tensors.empty()) {
        return result;
    }
    // Wrap the dimension first so that an out-of-range `dim` is still reported
    // even when the operator launch is skipped below.
    dim = op_plugin::utils::make_warp_dim(dim, input_tensors[0].dim());

    c10::SmallVector<at::Tensor, N> non_empty_inputs;
    for (const at::Tensor& candidate : input_tensors) {
        if (candidate.numel() != 0) {
            non_empty_inputs.emplace_back(candidate);
        }
    }
    if (non_empty_inputs.empty()) {
        // Avoid launching ConcatD with N == 0 and no inputs.
        return result;
    }

    at_npu::native::OpCommand cmd;
    cmd.Name("ConcatD");
    int64_t input_number = 0;
    for (const at::Tensor& input : non_empty_inputs) {
        // Inputs are named x0, x1, ... in the order they are attached.
        std::string input_name = "x" + std::to_string(input_number++);
        cmd.Input(input, input_name);
    }
    cmd.Output(result)
        .Attr("N", input_number)
        .Attr("concat_dim", dim)
        .Run();
    return result;
}
}
// Named-dimension overload of cat_out: resolves `dim` to a positional index
// using the first tensor and forwards to at::cat_out.
//
// @param tensors  tensors to concatenate; must not be empty
// @param dim      dimension name, looked up on tensors[0]
// @param result   output tensor
// @return whatever at::cat_out returns for `result`
at::Tensor& cat_out(at::TensorList tensors, at::Dimname dim, at::Tensor& result)
{
    TORCH_CHECK(tensors.size() > 0, "cat inputs should not be empty." + OPS_ERROR(ErrCode::PARAM));
    const auto dim_index = dimname_to_position(tensors[0], dim);
    return at::cat_out(result, tensors, dim_index);
}
// Named-dimension overload of cat: resolves `dim` to a positional index using
// the first tensor and forwards to at::cat.
//
// @param tensors  tensors to concatenate; must not be empty
// @param dim      dimension name, looked up on tensors[0]
// @return the tensor produced by at::cat
at::Tensor cat(at::TensorList tensors, at::Dimname dim)
{
    TORCH_CHECK(tensors.size() > 0, "cat inputs should not be empty." + OPS_ERROR(ErrCode::PARAM));
    const auto dim_index = dimname_to_position(tensors[0], dim);
    return at::cat(tensors, dim_index);
}
// Concatenates `tensors` along `dim` into the caller-provided `result`.
//
// - Raises (TORCH_CHECK) when the input list is empty.
// - When every input is filtered out by cat_dest_tensor_list, `result` is
//   resized to {0} and filled from a freshly applied tensor based on the
//   first input and `result`'s options.
// - Otherwise `result` is validated/prepared via CheckOut for ND format and
//   the inferred output size, then filled by cat_output_nocheck. If `result`
//   does not pass npu_utils::check_match, the concatenation runs on a
//   format_contiguous copy which is written back with format_fresh_view.
//
// @param tensors  tensors to concatenate; must not be empty
// @param dim      concatenation dimension; wrapped against the rank of the
//                 first kept input
// @param result   output tensor
// @return `result`
at::Tensor& cat_out(const at::ITensorListRef& tensors, int64_t dim, at::Tensor& result)
{
    auto materialized = tensors.materialize();
    // materialized[0] is indexed below; reject an empty list up front with the
    // same check the Dimname overloads use.
    TORCH_CHECK(materialized.size() > 0, "cat inputs should not be empty." + OPS_ERROR(ErrCode::PARAM));

    c10::SmallVector<at::Tensor, N> input_tensors = cat_dest_tensor_list(materialized);
    if (input_tensors.empty()) {
        at::Tensor output = npu_preparation::apply_tensor(materialized[0], result.options());
        result.resize_({0}).copy_(output);
        return result;
    }

    dim = op_plugin::utils::make_warp_dim(dim, input_tensors[0].dim());
    auto output_size = op_infer::cat_npu_output_size(input_tensors, dim);
    // NOTE(review): the dtype handed to CheckOut is the first input's dtype,
    // while cat_dest_tensor_list casts inputs to the promoted type — confirm
    // this is intended for mixed-dtype inputs.
    npu_preparation::CheckOut(
        {materialized[0].get()},
        result,
        ACL_FORMAT_ND,
        materialized[0].get().scalar_type(),
        output_size);

    if (!npu_utils::check_match(&result)) {
        at::Tensor contiguous_result = npu_utils::format_contiguous(result);
        cat_output_nocheck(contiguous_result, materialized, dim);
        npu_utils::format_fresh_view(result, contiguous_result);
    } else {
        cat_output_nocheck(result, materialized, dim);
    }
    return result;
}
// Concatenates `tensors` along `dim` and returns a newly allocated result.
//
// - Raises (TORCH_CHECK) when the input list is empty.
// - When every input is filtered out by cat_dest_tensor_list, returns a tensor
//   applied from the first input.
// - Otherwise allocates the output from the first kept input and the inferred
//   output size, then fills it via cat_output_nocheck.
//
// @param tensors  tensors to concatenate; must not be empty
// @param dim      concatenation dimension; wrapped against the rank of the
//                 first kept input
// @return the concatenated tensor
at::Tensor cat(const at::ITensorListRef& tensors, int64_t dim)
{
    auto materialized = tensors.materialize();
    // materialized[0] is indexed below; reject an empty list up front with the
    // same check the Dimname overloads use.
    TORCH_CHECK(materialized.size() > 0, "cat inputs should not be empty." + OPS_ERROR(ErrCode::PARAM));

    c10::SmallVector<at::Tensor, N> input_tensors = cat_dest_tensor_list(materialized);
    if (input_tensors.empty()) {
        return npu_preparation::apply_tensor(materialized[0]);
    }

    dim = op_plugin::utils::make_warp_dim(dim, input_tensors[0].dim());
    auto output_size = op_infer::cat_npu_output_size(input_tensors, dim);

    // Decide whether the output is forced to ND format. The scan stops at the
    // first tensor that is not 4-D (leaving the flag as it is) or at the first
    // 4-D tensor whose size(1) is not a multiple of 16 (clearing the flag).
    // Presumably size(1) is the channel dimension — TODO confirm.
    bool tensors_dim_check = true;
    for (const at::Tensor& t : materialized) {
        if (t.sizes().size() != 4) {
            break;
        }
        if (t.size(1) % 16 != 0) {
            tensors_dim_check = false;
            break;
        }
    }

    at::Tensor result = tensors_dim_check ?
        npu_preparation::apply_tensor(input_tensors[0], output_size) :
        npu_preparation::apply_tensor_with_format(input_tensors[0], output_size, ACL_FORMAT_ND);
    cat_output_nocheck(result, materialized, dim);
    return result;
}
}