* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "concat_v2_aicpu.h"
#include <complex>
#include <unordered_map>
using namespace std;
namespace {
// Operator type name; used in this file's log messages and in the kernel registration at the end of the file.
constexpr const char* ConcatV2 = "ConcatV2";
} // namespace
namespace aicpu {
uint32_t ConcatV2CpuKernel::CheckConcatV2Params(const CpuKernelContext& ctx)
{
AttrValue* n_ptr = ctx.GetAttr("N");
KERNEL_CHECK_NULLPTR(n_ptr, KERNEL_STATUS_PARAM_INVALID, "Get attr N failed.");
n_ = n_ptr->GetInt();
KERNEL_CHECK_FALSE(
(n_ >= 1), KERNEL_STATUS_PARAM_INVALID, "Attr N must >= 1, but got attr N[%ld]", n_);
uint32_t input_num = ctx.GetInputsSize();
KERNEL_CHECK_FALSE(
(static_cast<int64_t>(input_num) - 1 == n_), KERNEL_STATUS_PARAM_INVALID,
"Input num must equal attr N[%ld] + 1, but got input num[%u]", n_, input_num);
return KERNEL_STATUS_OK;
}
/**
 * @brief Extracts the concat_dim value from the input at index n_.
 *
 * The input must have rank 0, or rank 1 with exactly one element, and its data type must be
 * DT_INT32 or DT_INT64. The value is copied out of the tensor buffer with memcpy_s.
 *
 * @param ctx kernel context that supplies the inputs
 * @param concat_dim receives the parsed value, widened to int64_t
 * @return KERNEL_STATUS_OK on success, KERNEL_STATUS_PARAM_INVALID on any failed check
 */
uint32_t ConcatV2CpuKernel::ParseConcatDim(const CpuKernelContext& ctx, int64_t& concat_dim) const
{
    auto dim_tensor = ctx.Input(static_cast<uint32_t>(n_));
    KERNEL_CHECK_NULLPTR(dim_tensor, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim failed.");

    auto dim_shape = dim_tensor->GetTensorShape();
    KERNEL_CHECK_NULLPTR(dim_shape, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim shape failed.");

    // Accept a true scalar (rank 0) or a one-element vector (rank 1).
    int32_t rank = dim_shape->GetDims();
    const bool is_scalar = (rank == 0);
    const bool is_single_element_vector = (rank == 1) && (dim_shape->NumElements() == 1);
    KERNEL_CHECK_FALSE(
        is_scalar || is_single_element_vector,
        KERNEL_STATUS_PARAM_INVALID, "Input concat_dim should be a scalar integer, but got rank[%d].", rank);

    auto dim_dtype = dim_tensor->GetDataType();
    auto dim_data = dim_tensor->GetData();
    KERNEL_CHECK_NULLPTR(dim_data, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim data failed.");

    switch (dim_dtype) {
        case DT_INT32: {
            // Copy into a 32-bit temporary first, then widen.
            int32_t narrow_value = 0;
            KERNEL_CHECK_FALSE(
                memcpy_s(&narrow_value, sizeof(narrow_value), dim_data, sizeof(narrow_value)) == EOK,
                KERNEL_STATUS_PARAM_INVALID, "memcpy concat_dim(int32) failed.");
            concat_dim = static_cast<int64_t>(narrow_value);
            return KERNEL_STATUS_OK;
        }
        case DT_INT64: {
            KERNEL_CHECK_FALSE(
                memcpy_s(&concat_dim, sizeof(concat_dim), dim_data, sizeof(concat_dim)) == EOK,
                KERNEL_STATUS_PARAM_INVALID, "memcpy concat_dim(int64) failed.");
            return KERNEL_STATUS_OK;
        }
        default:
            break;
    }

    KERNEL_LOG_ERROR("Unsupported concat_dim data type: %d", dim_dtype);
    return KERNEL_STATUS_PARAM_INVALID;
}
/**
 * @brief Initializes the per-call state derived from concat_dim and input 0.
 *
 * Sets input_dims_ and data_type_ from input 0, resolves axis_ (a negative concat_dim has
 * input_dims_ added to it), and sets inputs_flat_dim0_ to the product of input 0's dimension
 * sizes in front of axis_.
 *
 * @param ctx kernel context that supplies the inputs
 * @return KERNEL_STATUS_OK on success; the ParseConcatDim status if parsing fails;
 *         KERNEL_STATUS_PARAM_INVALID if input 0 is unavailable or the axis is out of range
 */
uint32_t ConcatV2CpuKernel::InitConcatV2Params(const CpuKernelContext& ctx)
{
    int64_t requested_dim = 0;
    uint32_t parse_status = ParseConcatDim(ctx, requested_dim);
    KERNEL_CHECK_FALSE((parse_status == KERNEL_STATUS_OK), parse_status, "ParseConcatDim failed.");

    auto first_input = ctx.Input(0);
    KERNEL_CHECK_NULLPTR(first_input, KERNEL_STATUS_PARAM_INVALID, "Get input x0 failed.");
    auto first_shape = first_input->GetTensorShape();
    KERNEL_CHECK_NULLPTR(first_shape, KERNEL_STATUS_PARAM_INVALID, "Get input x0 shape failed.");

    input_dims_ = first_shape->GetDims();
    data_type_ = first_input->GetDataType();
    KERNEL_LOG_INFO("ConcatV2 init: data_type=%d, input_dims=%d.", data_type_, input_dims_);

    // Negative values are shifted by the rank of input 0.
    if (requested_dim < 0) {
        axis_ = requested_dim + input_dims_;
    } else {
        axis_ = requested_dim;
    }
    KERNEL_CHECK_FALSE(
        (axis_ >= 0 && axis_ < input_dims_), KERNEL_STATUS_PARAM_INVALID,
        "Input concat_dim need in range[%d, %d), but got %ld.", -input_dims_, input_dims_, requested_dim);

    // Multiply together all dimension sizes of input 0 that precede the axis.
    inputs_flat_dim0_ = 1;
    int32_t dim_index = 0;
    while (dim_index < axis_) {
        inputs_flat_dim0_ *= first_shape->GetDimSize(dim_index);
        ++dim_index;
    }
    return KERNEL_STATUS_OK;
}
/**
 * @brief Checks that input i is shape-compatible with input 0 for concatenation along axis_.
 *
 * The rank of input i must equal input_dims_, and every dimension other than axis_ must have
 * the same size as the corresponding dimension of input 0.
 *
 * @param input_i_ptr tensor being validated; a null pointer is rejected
 * @param input0_shape_ptr shape of input 0 used as the reference; a null pointer is rejected
 * @param i index of the validated input, used only in log messages
 * @return KERNEL_STATUS_OK when the shapes are compatible, KERNEL_STATUS_PARAM_INVALID otherwise
 */
uint32_t ConcatV2CpuKernel::ValidateInputShape(
    Tensor* input_i_ptr, TensorShape* input0_shape_ptr, int64_t i) const
{
    // Both arguments are raw pointers that are dereferenced below, so guard them first.
    KERNEL_CHECK_NULLPTR(input_i_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x[%ld] failed.", i);
    KERNEL_CHECK_NULLPTR(input0_shape_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x0 shape failed.");

    auto shape_ptr = input_i_ptr->GetTensorShape();
    KERNEL_CHECK_NULLPTR(shape_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x[%ld] shape failed.", i);

    int32_t dims = shape_ptr->GetDims();
    KERNEL_CHECK_FALSE(
        (dims == input_dims_), KERNEL_STATUS_PARAM_INVALID,
        "Ranks of inputs should match: shape[0]=%d vs. shape[%ld]=%d", input_dims_, i, dims);

    for (int32_t j = 0; j < input_dims_; ++j) {
        // The concat axis itself is allowed to differ between inputs.
        if (j == axis_) {
            continue;
        }
        int64_t dim_0j = input0_shape_ptr->GetDimSize(j);
        int64_t dim_ij = shape_ptr->GetDimSize(j);
        KERNEL_CHECK_FALSE(
            (dim_0j == dim_ij), KERNEL_STATUS_PARAM_INVALID,
            "Dim mismatch at axis %d: shape[0]=%ld vs. shape[%ld]=%ld", j, dim_0j, i, dim_ij);
    }
    return KERNEL_STATUS_OK;
}
uint32_t ConcatV2CpuKernel::CheckAndInitParams(const CpuKernelContext& ctx)
{
KERNEL_HANDLE_ERROR(this->CheckConcatV2Params(ctx), "CheckConcatV2Params failed.");
KERNEL_HANDLE_ERROR(this->InitConcatV2Params(ctx), "InitConcatV2Params failed.");
return KERNEL_STATUS_OK;
}
/**
 * @brief Kernel entry point: validates parameters, then dispatches on data_type_.
 *
 * After CheckAndInitParams succeeds, the DoCompute<T> instantiation matching data_type_ is
 * selected and invoked. Any CheckAndInitParams failure is reported as KERNEL_STATUS_PARAM_INVALID.
 *
 * @param ctx kernel context passed through to the typed compute routine
 * @return the status returned by DoCompute<T>, or KERNEL_STATUS_PARAM_INVALID when parameter
 *         checking fails or the data type is not supported
 */
uint32_t ConcatV2CpuKernel::Compute(CpuKernelContext& ctx)
{
    KERNEL_LOG_INFO("%s start.", ConcatV2);

    uint32_t init_status = CheckAndInitParams(ctx);
    KERNEL_CHECK_FALSE(
        (init_status == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID,
        "CheckAndInitParams failed, ret=[%u].", init_status);

    // Pick the typed implementation for the element type of input 0.
    using ComputeFunc = uint32_t (ConcatV2CpuKernel::*)(CpuKernelContext&);
    ComputeFunc typed_compute = nullptr;
    switch (data_type_) {
        case DT_FLOAT16:
            typed_compute = &ConcatV2CpuKernel::DoCompute<Eigen::half>;
            break;
        case DT_FLOAT:
            typed_compute = &ConcatV2CpuKernel::DoCompute<float>;
            break;
        case DT_INT8:
            typed_compute = &ConcatV2CpuKernel::DoCompute<int8_t>;
            break;
        case DT_INT16:
            typed_compute = &ConcatV2CpuKernel::DoCompute<int16_t>;
            break;
        case DT_INT32:
            typed_compute = &ConcatV2CpuKernel::DoCompute<int32_t>;
            break;
        case DT_INT64:
            typed_compute = &ConcatV2CpuKernel::DoCompute<int64_t>;
            break;
        case DT_UINT8:
            typed_compute = &ConcatV2CpuKernel::DoCompute<uint8_t>;
            break;
        case DT_UINT16:
            typed_compute = &ConcatV2CpuKernel::DoCompute<uint16_t>;
            break;
        case DT_UINT32:
            typed_compute = &ConcatV2CpuKernel::DoCompute<uint32_t>;
            break;
        case DT_UINT64:
            typed_compute = &ConcatV2CpuKernel::DoCompute<uint64_t>;
            break;
        case DT_BOOL:
            typed_compute = &ConcatV2CpuKernel::DoCompute<bool>;
            break;
        case DT_DOUBLE:
            typed_compute = &ConcatV2CpuKernel::DoCompute<double>;
            break;
        case DT_COMPLEX64:
            typed_compute = &ConcatV2CpuKernel::DoCompute<std::complex<float>>;
            break;
        case DT_COMPLEX128:
            typed_compute = &ConcatV2CpuKernel::DoCompute<std::complex<double>>;
            break;
        case DT_BFLOAT16:
            typed_compute = &ConcatV2CpuKernel::DoCompute<Eigen::bfloat16>;
            break;
        default:
            break;
    }
    if (typed_compute == nullptr) {
        KERNEL_LOG_ERROR("Unsupported datatype[%d]", data_type_);
        return KERNEL_STATUS_PARAM_INVALID;
    }

    uint32_t status = (this->*typed_compute)(ctx);
    if (status != KERNEL_STATUS_OK) {
        KERNEL_LOG_ERROR("%s failed, result=[%u].", ConcatV2, status);
        return status;
    }
    KERNEL_LOG_INFO("%s success.", ConcatV2);
    return status;
}
// Kernel registration: passes the ConcatV2 name constant and the ConcatV2CpuKernel class to the
// REGISTER_CPU_KERNEL macro. NOTE(review): the macro is defined outside this file; presumably it
// makes the kernel discoverable by operator type name - confirm against the macro definition.
REGISTER_CPU_KERNEL(ConcatV2, ConcatV2CpuKernel);
}