#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TypeProperties.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
bool all_contiguous(at::TensorList tensors)
{
for (const auto& t : tensors) {
if (!t.is_contiguous()) {
return false;
}
}
return true;
}
bool isSupportedDtype(at::ScalarType dtype)
{
return dtype == at::ScalarType::Float ||
dtype == at::ScalarType::Half ||
dtype == at::ScalarType::BFloat16;
}
bool isSupportedDtypeCombination(at::ScalarType inputDtype, at::ScalarType outputDtype)
{
if (!isSupportedDtype(inputDtype) || !isSupportedDtype(outputDtype)) {
return false;
}
if (inputDtype == at::ScalarType::Float) {
return outputDtype == at::ScalarType::Float;
}
return true;
}
static c10::SmallVector<int64_t, op_infer::SIZE> get_chunk_cat_out_sizes(
at::TensorList tensors,
int64_t dim,
int64_t num_chunks)
{
auto first_tensor = tensors[0];
auto first_sizes = first_tensor.sizes();
int64_t input_dim_num = first_sizes.size();
c10::SmallVector<int64_t, op_infer::SIZE> out_shape;
out_shape.resize((dim + 1) + 1);
for (int64_t j = 0; j < dim; j++) {
out_shape[j] = first_sizes[j];
}
out_shape[dim] = num_chunks;
int64_t outputCol = 0;
for (const auto& tensor : tensors)
{
auto sizes = tensor.sizes();
int64_t tensor_num_dim = sizes.size();
int64_t chunkDimSize = sizes[dim];
if (num_chunks>0) {
int64_t chunkCol = (chunkDimSize + num_chunks - 1) / num_chunks;
int64_t dim1Size = chunkCol;
for (int64_t j = dim + 1; j < tensor_num_dim; j++) {
dim1Size *= sizes[j];
}
outputCol += dim1Size;
}
}
out_shape[dim + 1] = outputCol;
return out_shape;
}
at::Tensor _chunk_cat( at::TensorList tensors, int64_t dim, int64_t num_chunks)
{
static bool npu_support_aclnn = check_aclnn_kernel_available("aclnnChunkCat")
&& c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950;
if (!npu_support_aclnn) {
return at::native::_chunk_cat(tensors, dim, num_chunks);
}
auto view_sizes = get_chunk_cat_out_sizes(tensors, dim, num_chunks);
at::Tensor result = npu_preparation::apply_tensor_without_format(view_sizes, tensors[0].scalar_type());
if (all_contiguous(tensors) &&
isSupportedDtype(tensors[0].scalar_type()) &&
dim == 0) {
EXEC_NPU_CMD(aclnnChunkCat, tensors, dim, num_chunks, result);
} else {
at::native::_chunk_cat_out(tensors, dim, num_chunks, result);
}
return result;
}
at::Tensor& _chunk_cat_out( at::TensorList tensors, int64_t dim, int64_t num_chunks, at::Tensor& out)
{
static bool npu_support_aclnn = check_aclnn_kernel_available("aclnnChunkCat")
&& c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950;
if (!npu_support_aclnn) {
return at::native::_chunk_cat_out(tensors, dim, num_chunks, out);
}
TORCH_CHECK(
tensors[0].device() == out.device(),
"_chunk_cat_out: mismatch between input and out tensor devices");
bool both_input_output_contiguous = all_contiguous(tensors) && out.is_contiguous();
auto view_sizes = get_chunk_cat_out_sizes(tensors, dim, num_chunks);
npu_preparation::check_tensor({tensors[0]}, out, at::IntArrayRef(view_sizes));
if (both_input_output_contiguous &&
isSupportedDtypeCombination(tensors[0].scalar_type(), out.scalar_type()) &&
dim == 0) {
EXEC_NPU_CMD(aclnnChunkCat, tensors, dim, num_chunks, out);
} else {
at::native::_chunk_cat_out(tensors, dim, num_chunks, out);
}
return out;
}
}