* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "cross_aicpu.h"
#include <algorithm>
#include <chrono>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 2;
const int64_t KNum2 = 2;
const int64_t kParallelNum = static_cast<int64_t>(2 * 1024);
const char *const kCross = "Cross";
const int64_t kThreeNum = 3;
const int64_t kDimDef = static_cast<int64_t>(-65530);
#define CROSS_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = CrossBcastCompute<TYPE>(CTX); \
if (result != static_cast<uint32_t>(KERNEL_STATUS_OK)) { \
KERNEL_LOG_ERROR("Cross kernel compute failed."); \
return result; \
} \
break; \
}
}
namespace aicpu {
KernelStatus CrossCpuKernel::GetDimAndCheck(const CpuKernelContext &ctx, const BCalcInfo &calc_info) {
auto &shape_out = calc_info.shape_out;
int64_t rank = shape_out.size();
AttrValue *dim_attr = ctx.GetAttr("dim");
dim_ = (dim_attr == nullptr) ? kDimDef : dim_attr->GetInt();
if (dim_ == kDimDef) {
for (int64_t i = 0; i < rank; i++) {
if (shape_out[i] == kThreeNum) {
dim_ = i;
break;
}
if ((i == (rank - 1)) && (shape_out[i] != kThreeNum)) {
KERNEL_LOG_ERROR("The size of inputs dim should be 3,but got [%ld].",
shape_out[i]);
return KERNEL_STATUS_PARAM_INVALID;
}
}
}
if ((dim_ < (-rank)) || (dim_ > (rank - 1))) {
KERNEL_LOG_ERROR("dim should between [%ld] and [%ld],but got [%ld].",
-rank,
rank - 1, dim_);
return KERNEL_STATUS_PARAM_INVALID;
}
if (dim_ < 0) {
dim_ = rank + dim_;
}
if (shape_out[dim_] != kThreeNum) {
KERNEL_LOG_ERROR("The size of inputs dim should be 3,but got [%ld].",
shape_out[dim_]);
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t CrossCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum),
"Cross check input and output number failed.");
auto output_type = ctx.Output(0)->GetDataType();
switch (output_type) {
CROSS_COMPUTE_CASE(DT_DOUBLE, double, ctx);
CROSS_COMPUTE_CASE(DT_FLOAT, float, ctx);
CROSS_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx);
CROSS_COMPUTE_CASE(DT_INT8, int8_t, ctx);
CROSS_COMPUTE_CASE(DT_INT16, int16_t, ctx);
CROSS_COMPUTE_CASE(DT_INT32, int32_t, ctx);
CROSS_COMPUTE_CASE(DT_INT64, int64_t, ctx);
CROSS_COMPUTE_CASE(DT_UINT8, uint8_t, ctx);
CROSS_COMPUTE_CASE(DT_UINT16, uint16_t, ctx);
CROSS_COMPUTE_CASE(DT_UINT32, uint32_t, ctx);
CROSS_COMPUTE_CASE(DT_UINT64, uint64_t, ctx);
CROSS_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx);
CROSS_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx);
default:
KERNEL_LOG_ERROR("Cross kernel data type [%s] not support.",
DTypeStr(output_type).c_str());
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
}
return static_cast<uint32_t>(KERNEL_STATUS_OK);
}
std::vector<int64_t> CrossCpuKernel::StridesCompute(
int64_t stride_length, std::vector<int64_t> data_dims) {
std::vector<int64_t> stride;
int64_t stride_tmp = 1;
for (int64_t i = stride_length - static_cast<int64_t>(1); i > static_cast<int64_t>(-1); i--) {
(void)stride.insert(stride.begin(), stride_tmp);
stride_tmp *= data_dims[i];
}
return stride;
}
void CrossCpuKernel::InitStride(const BCalcInfo &calc_info) {
const int64_t n = calc_info.shape_out.size();
stride_ = StridesCompute(n, calc_info.shape_out);
data_stride_ = stride_[dim_];
}
template <typename T1>
void CrossCpuKernel::SetResult(const BCalcInfo &calc_info,
int64_t data_start_bcast) const {
auto input1_data_addr = static_cast<const T1 *>(calc_info.input_0->GetData());
auto input2_data_addr = static_cast<const T1 *>(calc_info.input_1->GetData());
auto output_data_addr = static_cast<T1 *>(calc_info.output->GetData());
int64_t input1_data_start = calc_info.x_indexes[data_start_bcast];
int64_t input2_data_start = calc_info.y_indexes[data_start_bcast];
int64_t input1_data_stride1 = calc_info.x_indexes[data_start_bcast + data_stride_];
int64_t input1_data_stride2 = calc_info.x_indexes[data_start_bcast + KNum2 * data_stride_];
int64_t input2_data_stride1 = calc_info.y_indexes[data_start_bcast + data_stride_];
int64_t input2_data_stride2 = calc_info.y_indexes[data_start_bcast + KNum2 * data_stride_];
output_data_addr[data_start_bcast] =
input1_data_addr[input1_data_stride1] * input2_data_addr[input2_data_stride2] -
input1_data_addr[input1_data_stride2] * input2_data_addr[input2_data_stride1];
output_data_addr[data_start_bcast + data_stride_] =
input1_data_addr[input1_data_stride2] * input2_data_addr[input2_data_start] -
input1_data_addr[input1_data_start] * input2_data_addr[input2_data_stride2];
output_data_addr[data_start_bcast + KNum2 * data_stride_] =
input1_data_addr[input1_data_start] * input2_data_addr[input2_data_stride1] -
input1_data_addr[input1_data_stride1] * input2_data_addr[input2_data_start];
}
template <typename T1>
void CrossCpuKernel::ParallelCompute(const BCalcInfo &calc_info, int64_t start, int64_t end) {
auto &data_shape = calc_info.shape_out;
const int64_t data_dim = data_shape.size();
std::vector<int64_t> position_in_dims(data_dim);
int64_t index_in_curr_dim = start;
int64_t data_start = 0;
for (int64_t i = 0; i < data_dim; i++) {
if (i == dim_) {
continue;
}
position_in_dims[i] = index_in_curr_dim % data_shape[i];
data_start += (index_in_curr_dim % data_shape[i]) * stride_[i];
index_in_curr_dim = index_in_curr_dim / data_shape[i];
}
while (start < end) {
SetResult<T1>(calc_info, data_start);
start++;
for (int64_t i = 0; i < data_dim; i++) {
if (i == dim_) {
continue;
}
position_in_dims[i]++;
data_start += stride_[i];
if ((position_in_dims[i] == data_shape[i]) && (i != data_dim - 1)) {
data_start -= position_in_dims[i] * stride_[i];
position_in_dims[i] = 0;
} else {
break;
}
}
}
}
template <typename T1>
uint32_t CrossCpuKernel::CrossCompute(const CpuKernelContext &ctx, const BCalcInfo &calc_info) {
KERNEL_HANDLE_ERROR(static_cast<uint32_t>(GetDimAndCheck(ctx, calc_info)),
"[%s] check params failed.", kCross);
int64_t total = ctx.Output(0)->NumElements() / 3;
InitStride(calc_info);
auto cross_shard = [this, &calc_info](int64_t start, int64_t end) {
ParallelCompute<T1>(calc_info, start, end);
};
if (total > kParallelNum) {
const int64_t max_core_num = std::max(
static_cast<int64_t>(1),
static_cast<int64_t>(aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2));
const int64_t per_unit_size = total / std::min(total, max_core_num);
KERNEL_HANDLE_ERROR(
CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, cross_shard),
"Cross compute failed");
} else {
cross_shard(0, total);
}
return static_cast<uint32_t>(KERNEL_STATUS_OK);
}
template <typename T1>
uint32_t CrossCpuKernel::CrossBcastCompute(const CpuKernelContext &ctx) {
BCalcInfo calc_info;
calc_info.input_0 = ctx.Input(0);
calc_info.input_1 = ctx.Input(1);
calc_info.output = ctx.Output(0);
DataType input0_type = calc_info.input_0->GetDataType();
DataType input1_type = calc_info.input_1->GetDataType();
KERNEL_CHECK_FALSE((input0_type == input1_type), KERNEL_STATUS_PARAM_INVALID,
"DataType of x1 [%d] should be same as x2 [%d].",
static_cast<int32_t>(input0_type), static_cast<int32_t>(input1_type))
KERNEL_LOG_INFO("CpuKernel[%s], input x1 : addr[%p], size[%lu];"
"input x2: addr[%p], size[%lu];"
"output: addr[%p], size[%lu].",
ctx.GetOpType().c_str(), calc_info.input_0->GetData(),
calc_info.input_0->GetDataSize(), calc_info.input_1->GetData(),
calc_info.input_1->GetDataSize(), calc_info.output->GetData(),
calc_info.output->GetDataSize());
Bcast bcast;
KERNEL_HANDLE_ERROR(bcast.GenerateBcastInfo(calc_info),
"Generate broadcast info failed.")
bcast.BCastIndexes(calc_info.x_indexes, calc_info.y_indexes);
bcast.GetBcastVec(calc_info);
return CrossCompute<T1>(ctx, calc_info);
}
REGISTER_CPU_KERNEL(kCross, CrossCpuKernel);
}