* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "concat_reg_api_call.h"
#include "api_call/utils/api_call_factory.h"
#include "ascir_ops.h"
#include "ascir_utils.h"
namespace codegen {
/**
 * Caches the node view for later queries (IsShareInputs, IsTile, CanUseGather, ...).
 * No attribute is validated here.
 *
 * @param node the Concat node being generated
 * @return always ge::SUCCESS
 */
Status ConcatRegApiCall::ParseAttr(const ascir::NodeView &node) {
  this->node_ = node;
  return ge::SUCCESS;
}
/**
 * Emits the kernel code for this Concat node into `result`.
 *
 * Dispatch order:
 *   1. CanConcatOneAxis -> GenerateForOneAxis (no tiling / tmp buffer needed);
 *   2. IsAllAligned     -> GenerateForAllAligned;
 *   3. otherwise a tmp buffer is required: the Gather path is used when CanUseGather enabled it and the
 *      destination column size is a compile-time constant, else GenerateDefault.
 *
 * @param tpipe        pipe description; its tiler and tmp buffer name are used by the generated code
 * @param current_axis unused
 * @param inputs       input tensors in concat order (must be non-empty)
 * @param outputs      output tensors (must be non-empty; only outputs[0] is used)
 * @param result       receives the generated code on success
 * @return ge::SUCCESS, or a failure status from any check or helper
 */
Status ConcatRegApiCall::Generate(const TPipe &tpipe, const std::vector<ascir::AxisId> &current_axis,
                                  const vector<std::reference_wrapper<const Tensor>> &inputs,
                                  const vector<std::reference_wrapper<const Tensor>> &outputs,
                                  string &result) const {
  (void)current_axis;
  GE_CHK_BOOL_RET_STATUS((!inputs.empty()) && (!outputs.empty()), ge::FAILED,
                         "Codegen input or output tensor is empty");
  const auto &x0 = inputs[0].get();
  const auto &y = outputs[0].get();
  size_t concat_dim = 0UL;
  GE_ASSERT_SUCCESS(ParseConcatDim(x0, y, concat_dim), "Failed to parse concat dim");
  std::stringstream ss;
  if (CanConcatOneAxis(inputs, y)) {
    GE_ASSERT_SUCCESS(GenerateForOneAxis(inputs, y, ss));
    result = ss.str();
    return ge::SUCCESS;
  }
  ConcatTiling concat_tiling;
  GE_ASSERT_SUCCESS(InitializeTiling(concat_dim, inputs, y, concat_tiling));
  if (IsAllAligned(concat_tiling, concat_tiling.src_col_size_exprs)) {
    GE_ASSERT_SUCCESS(GenerateForAllAligned(inputs, y, concat_tiling, tpipe.tiler, ss));
  } else {
    // The tmp buffer used by the non-aligned paths is looked up under life-time axis id -1.
    constexpr int64_t kLifeTimeAxisId = -1L;
    const auto it = this->tmp_buf_id.find(kLifeTimeAxisId);
    GE_ASSERT_TRUE(it != this->tmp_buf_id.end(), "ConcatRegApiCall cannot find tmp buffer id to use.");
    const int64_t id = it->second;
    GE_ASSERT_SUCCESS(CanUseGather(concat_tiling));
    // The dedicated Gather kernel needs a constant dst column size; otherwise GenerateDefault emits
    // the dynamic variant (ConcatExtendDyn) when can_use_gather is set.
    if (concat_tiling.can_use_gather && concat_tiling.dst_col_size_expr.IsConstExpr()) {
      GE_ASSERT_SUCCESS(GenerateForGather(inputs, y, concat_tiling, tpipe, ss, id));
    } else {
      GE_ASSERT_SUCCESS(GenerateDefault(inputs, y, concat_tiling, tpipe, ss, id));
    }
  }
  result = ss.str();
  return ge::SUCCESS;
}
/**
 * Tells whether the node's inputs use exactly one distinct queue id.
 *
 * @return true if all inputs share a single queue id; false if the ids differ or the node has no inputs
 */
bool ConcatRegApiCall::IsShareInputs() const {
  std::set<int64_t> distinct_que_ids;
  const auto num_inputs = node_->inputs.Size();
  for (uint32_t idx = 0U; idx < num_inputs; ++idx) {
    distinct_que_ids.insert(node_->inputs[idx].attr.que.id);
  }
  const bool single_queue = (distinct_que_ids.size() == 1UL);
  return single_queue;
}
/**
 * Reports whether the inputs of this Concat should preferably be placed in contiguous buffers.
 * That is the case when the Gather path may be used:
 *   - the node is not a tile, i.e. its inputs do not all come from a single source; and
 *   - the input shapes are not known to differ (equal or unknown is fine).
 *
 * @return true if contiguous input buffers are preferred, false otherwise
 */
bool ConcatRegApiCall::AreContiguousBufsPreferred() const {
  GE_CHK_BOOL_RET_SPECIAL_STATUS(IsTile(), false, "%s all inputs are from single source", node_->GetNamePtr());
  const auto is_all_inputs_shape_equal = ascir::utils::AreConcatInputShapesEqual(node_);
  GE_CHK_BOOL_RET_SPECIAL_STATUS((is_all_inputs_shape_equal == af::TriBool::kFalse), false,
                                 "%s cannot use Gather, input shapes differ", node_->GetNamePtr());
  GELOGD("%s may use Gather, contiguous input bufs are preferred", node_->GetNamePtr());
  return true;
}
/**
 * Decides whether the one-axis fast path (GenerateForOneAxis) applies.
 *
 * Requirements:
 *   - y has exactly two vectorized strides, and the first one is zero;
 *   - every input has as many vectorized axes as y has vectorized strides;
 *   - every input's last vectorized axis is a compile-time constant of at most
 *     256 / sizeof(element) elements. A non-constant size counts as too large.
 * The GE assert/warn macros bail out of the function when an internal check fails.
 *
 * @param inputs input tensors in concat order
 * @param y      output tensor
 * @return true if the one-axis path can be used
 */
bool ConcatRegApiCall::CanConcatOneAxis(const std::vector<std::reference_wrapper<const Tensor>> &inputs,
                                        const Tensor &y) {
  constexpr int64_t kVecLen = 256;
  const auto dtype_bytes = ge::GetSizeByDataType(y.dtype);
  GE_ASSERT_TRUE(dtype_bytes > 0);
  const auto max_elems = kVecLen / dtype_bytes;

  const bool has_two_strides = (y.vectorized_strides.size() == 2UL);
  if ((!has_two_strides) || !(y.vectorized_strides[0] == af::ops::Zero)) {
    return false;
  }

  for (const auto &input_ref : inputs) {
    const auto &src = input_ref.get();
    GE_WARN_ASSERT(src.vectorized_axis.size() == y.vectorized_strides.size());
    const auto last_pos = src.vectorized_axis_pos.back();
    auto last_axis_expr = src.axis_size[last_pos];
    int64_t const_size = std::numeric_limits<int64_t>::max();
    if (last_axis_expr.IsConstExpr()) {
      GE_WARN_ASSERT(last_axis_expr.GetConstValue(const_size));
    }
    GE_CHK_BOOL_RET_STATUS_NOLOG((const_size <= max_elems), false);
  }
  return true;
}
/**
 * Emits the generic ConcatExtend / ConcatExtendDyn call together with its tiling definition.
 *
 * 8-byte elements are emitted as uint32_t with a B64ToB32-rescaled tiling. 1-byte elements are emitted
 * as uint16_t with a B8ToB16-rescaled tiling when NeedB8ToB16 allows it.
 *
 * @param inputs     source tensors in concat order
 * @param y          destination tensor
 * @param tiling     tiling prepared by InitializeTiling / CanUseGather
 * @param t_pipe     provides the tiler and the tmp buffer name
 * @param ss         generated code is appended here
 * @param tmp_buf_id suffix of the tmp buffer passed to the kernel
 * @return ge::SUCCESS
 */
ge::Status ConcatRegApiCall::GenerateDefault(const vector<std::reference_wrapper<const Tensor>> &inputs,
                                             const Tensor &y, const ConcatApiCall::ConcatTiling &tiling,
                                             const TPipe &t_pipe, std::stringstream &ss, const int64_t tmp_buf_id) {
  std::string dtype_name;
  // NOTE(review): the status is ignored. On failure dtype_name stays empty unless one of the
  // re-packing branches below overwrites it. Confirm this is intended.
  (void)Tensor::DtypeName(y.dtype, dtype_name);
  if (tiling.data_type_size == sizeof(uint64_t)) {
    const ConcatTiling tiling_b32 = B64ToB32(tiling);
    DefineConcatTiling(tiling_b32, t_pipe.tiler, ss);
    dtype_name = "uint32_t";
  } else if (NeedB8ToB16(tiling)) {
    GELOGD("%s can use b16 concat", dtype_name.c_str());
    const ConcatTiling tiling_b16 = B8ToB16(tiling);
    DefineConcatTiling(tiling_b16, t_pipe.tiler, ss);
    dtype_name = "uint16_t";
  } else {
    DefineConcatTiling(tiling, t_pipe.tiler, ss);
  }
  NormalizeDtype(dtype_name);
  GenSrcAddrs(inputs, dtype_name, ss);
  if (tiling.can_use_gather) {
    GELOGD("use ConcatExtendDyn");
    ss << "concat::ConcatExtendDyn<";
  } else {
    GELOGD("use ConcatExtend");
    ss << "concat::ConcatExtend<";
  }
  ss << dtype_name << ", " << inputs.size();
  // With Gather enabled but input-shape equality unknown at codegen time, an extra `true` template
  // argument is emitted (presumably selecting a run-time shape check in the kernel; TODO confirm).
  if (tiling.can_use_gather && (tiling.all_inputs_shape_equal == af::TriBool::kUnknown)) {
    ss << ", true";
  }
  ss << ">("
     << "(" << dtype_name << " *)" << y << ".GetPhyAddr()"
     << ", concat_src_addrs, " << t_pipe.tmp_buf << "_" << std::to_string(tmp_buf_id) << ", concat_tiling);"
     << std::endl;
  return ge::SUCCESS;
}
/**
 * Emits a ConcatExtendByGather call together with its ConcatByGatherTiling definition.
 *
 * 8-byte elements are emitted as uint32_t with a B64ToB32-rescaled tiling. 1-byte elements are emitted
 * as uint16_t with a B8ToB16-rescaled tiling when NeedB8ToB16 allows it.
 *
 * @param inputs     source tensors in concat order
 * @param y          destination tensor
 * @param tiling     tiling prepared by InitializeTiling / CanUseGather
 * @param t_pipe     provides the tiler and the tmp buffer name
 * @param ss         generated code is appended here
 * @param tmp_buf_id suffix of the tmp buffer passed to the kernel
 * @return ge::SUCCESS
 */
ge::Status ConcatRegApiCall::GenerateForGather(const vector<std::reference_wrapper<const Tensor>> &inputs,
                                               const Tensor &y, const ConcatApiCall::ConcatTiling &tiling,
                                               const TPipe &t_pipe, std::stringstream &ss,
                                               const int64_t tmp_buf_id) {
  std::string dtype_name;
  // NOTE(review): the status is ignored. On failure dtype_name stays empty unless one of the
  // re-packing branches below overwrites it. Confirm this is intended.
  (void)Tensor::DtypeName(y.dtype, dtype_name);
  if (tiling.data_type_size == sizeof(uint64_t)) {
    const ConcatTiling tiling_b32 = B64ToB32(tiling);
    DefineConcatTilingGather(tiling_b32, t_pipe.tiler, ss);
    dtype_name = "uint32_t";
  } else if (NeedB8ToB16(tiling)) {
    GELOGD("%s can use b16 concat", dtype_name.c_str());
    const ConcatTiling tiling_b16 = B8ToB16(tiling);
    DefineConcatTilingGather(tiling_b16, t_pipe.tiler, ss);
    dtype_name = "uint16_t";
  } else {
    DefineConcatTilingGather(tiling, t_pipe.tiler, ss);
  }
  NormalizeDtype(dtype_name);
  GenSrcAddrs(inputs, dtype_name, ss);
  ss << "concat::ConcatExtendByGather<" << dtype_name << ", " << inputs.size() << ">("
     << "(" << dtype_name << " *)" << y << ".GetPhyAddr()"
     << ", " << "concat_src_addrs, "
     << t_pipe.tmp_buf << "_" << std::to_string(tmp_buf_id)
     << ", concat_tiling);" << std::endl;
  GELOGD("use ConcatExtendByGather");
  return ge::SUCCESS;
}
/**
 * Checks whether 1-byte data can be re-packed as 2-byte elements.
 *
 * This requires a 1-byte element size. For every input, the relevant column size must also be a
 * multiple of 2: last_dim_size_exprs for padded inputs, src_col_size_exprs otherwise.
 *
 * @param tiling tiling to inspect
 * @return true if B8ToB16 may be applied to `tiling`
 */
bool ConcatRegApiCall::NeedB8ToB16(const ConcatApiCall::ConcatTiling &tiling) {
  const bool is_byte_type = (tiling.data_type_size == sizeof(uint8_t));
  if (!is_byte_type) {
    return false;
  }
  const auto pack_factor = af::Symbol(sizeof(uint16_t));
  const size_t num_srcs = tiling.is_padded.size();
  for (size_t idx = 0U; idx < num_srcs; ++idx) {
    const auto &cols = tiling.is_padded[idx] ? tiling.last_dim_size_exprs[idx] : tiling.src_col_size_exprs[idx];
    if (af::sym::Mod(cols, pack_factor) != af::ops::Zero) {
      return false;
    }
  }
  return true;
}
/**
 * Returns a copy of `tiling` expressed in 32-bit element units, for 64-bit data.
 *
 * The dst column size and every src column size, src row stride, second-last-dim stride and last-dim size
 * are multiplied by 2 (sizeof(uint64_t) / sizeof(uint32_t)). The row count and all other fields are copied
 * unchanged.
 *
 * @param tiling tiling in 64-bit element units
 * @return rescaled tiling
 */
ConcatApiCall::ConcatTiling ConcatRegApiCall::B64ToB32(const ConcatTiling &tiling) {
  auto factor = af::Symbol(sizeof(uint64_t) / sizeof(uint32_t));
  const auto scale_each = [&factor](auto &exprs) {
    for (auto &expr : exprs) {
      expr = expr * factor;
    }
  };
  ConcatTiling widened = tiling;
  widened.dst_col_size_expr = tiling.dst_col_size_expr * factor;
  scale_each(widened.src_col_size_exprs);
  scale_each(widened.src_row_strides);
  scale_each(widened.src_non_zero_strides);
  scale_each(widened.last_dim_size_exprs);
  return widened;
}
/**
 * Returns a copy of `tiling` expressed in 16-bit element units, for 8-bit data.
 *
 * The dst column size and every src column size, src row stride, second-last-dim stride and last-dim size
 * are divided by 2. The row count and all other fields are copied unchanged.
 * Callers are expected to check NeedB8ToB16 first.
 *
 * @param tiling tiling in 8-bit element units
 * @return rescaled tiling
 */
ConcatApiCall::ConcatTiling ConcatRegApiCall::B8ToB16(const ConcatTiling &tiling) {
  const auto &divisor = af::Symbol(2);
  const auto halve_each = [&divisor](auto &exprs) {
    for (auto &expr : exprs) {
      expr = expr / divisor;
    }
  };
  ConcatTiling packed = tiling;
  packed.dst_col_size_expr = tiling.dst_col_size_expr / divisor;
  halve_each(packed.src_col_size_exprs);
  halve_each(packed.src_row_strides);
  halve_each(packed.src_non_zero_strides);
  halve_each(packed.last_dim_size_exprs);
  return packed;
}
/**
 * Decides whether the Gather-based kernels may be used and records the decision in `tiling`.
 *
 * Side effects on `tiling`:
 *   - all_inputs_shape_equal is set to kTrue for a tile, otherwise to AreConcatInputShapesEqual(node_),
 *     unless an earlier check already returned;
 *   - can_use_gather is set to true only if every check passes. Otherwise it is left unchanged.
 *
 * Gather is rejected when:
 *   - any input is padded;
 *   - the node is not a tile and the input buffers are not contiguous;
 *   - the first source column size is a constant exceeding 128 (= 256 / 2) bytes.
 *
 * @param tiling tiling to update
 * @return ge::SUCCESS when a decision was made; a failure status if the tiling is malformed
 */
ge::Status ConcatRegApiCall::CanUseGather(ConcatTiling &tiling) const {
  GE_CHK_BOOL_RET_SPECIAL_STATUS(tiling.any_padded, ge::SUCCESS, "cannot use Gather: input is padded");
  if (IsTile()) {
    tiling.all_inputs_shape_equal = af::TriBool::kTrue;
  } else {
    GE_CHK_BOOL_RET_SPECIAL_STATUS((!is_input_tbuf_contiguous), ge::SUCCESS,
                                   "cannot use Gather: input bufs cannot be contiguous");
    tiling.all_inputs_shape_equal = ascir::utils::AreConcatInputShapesEqual(node_);
  }
  GE_ASSERT_TRUE(!tiling.src_col_size_exprs.empty(), "ConcatRegApiCall tiling has no source column sizes.");
  if (tiling.src_col_size_exprs[0].IsConstExpr()) {
    uint32_t src_col_size = 0U;
    GE_ASSERT_TRUE(tiling.src_col_size_exprs[0].GetConstValue(src_col_size));
    constexpr uint64_t kMaxSrcSize = 256U / 2U;
    // Widen before multiplying so a large column count cannot wrap around and pass the limit.
    const uint64_t src_col_bytes =
        static_cast<uint64_t>(src_col_size) * static_cast<uint64_t>(tiling.data_type_size);
    if (src_col_bytes > kMaxSrcSize) {
      GELOGD("src col size = %llu, over %llu, cannot use Gather", static_cast<unsigned long long>(src_col_bytes),
             static_cast<unsigned long long>(kMaxSrcSize));
      return ge::SUCCESS;
    }
  }
  tiling.can_use_gather = true;
  return ge::SUCCESS;
}
bool ConcatRegApiCall::IsTile() const {
std::set<const af::OutDataAnchor *> src_anchors;
for (const auto &node_and_out_anchor : node_->GetInDataNodesAndAnchors()) {
src_anchors.emplace(node_and_out_anchor.second.get());
}
const bool is_tile = (src_anchors.size() == 1UL);
GELOGI("is tile by concat case = %d", static_cast<int32_t>(is_tile));
return is_tile;
}
/**
 * Rewrites the signed 8-bit type name "int8_t" to "uint8_t" before it is emitted.
 * Every other name is left untouched.
 *
 * @param dtype_name C++ type name, modified in place
 */
void ConcatRegApiCall::NormalizeDtype(std::string &dtype_name) {
  if (dtype_name != "int8_t") {
    return;
  }
  dtype_name = "uint8_t";
}
/**
 * Name of the generated tiling struct template: "concat::ConcatTilingPadded" when any input is padded,
 * "concat::ConcatTiling" otherwise.
 *
 * @param tiling tiling being emitted
 * @return qualified template name, without template arguments
 */
std::string ConcatRegApiCall::GetTilingDataType(const ConcatTiling &tiling) {
  return tiling.any_padded ? std::string("concat::ConcatTilingPadded") : std::string("concat::ConcatTiling");
}
/**
 * Emits the `concat_tiling` constant used by ConcatExtend / ConcatExtendDyn.
 *
 * The row count, dst column count and per-source column counts are always emitted. For padded tilings,
 * the per-source row strides, second-last-dim strides and gather-mask dim sizes are emitted as well;
 * non-padded sources get "0" in the last two arrays.
 *
 * @param tiling tiling to serialize
 * @param tiler  renders symbolic sizes as code
 * @param ss     generated code is appended here
 */
void ConcatRegApiCall::DefineConcatTiling(const ConcatTiling &tiling, const Tiler &tiler, std::stringstream &ss) {
  // Writes `header`, then `render(item, index) + ", "` for each item, then closes the initializer list.
  const auto emit_array = [&ss](const char *header, const auto &items, const auto &render) {
    ss << header;
    size_t index = 0UL;
    for (const auto &item : items) {
      ss << render(item, index) << ", ";
      ++index;
    }
    ss << "}," << '\n';
  };
  const auto plain_size = [&tiler](const auto &expr, size_t) { return tiler.Size(expr, true); };
  const auto padded_only = [&tiler, &tiling](const auto &expr, const size_t index) {
    return tiling.is_padded[index] ? tiler.Size(expr) : "0";
  };

  const std::string tiling_type = GetTilingDataType(tiling);
  ss << "const " << tiling_type << "<" << tiling.src_col_size_exprs.size() << "> concat_tiling {" << '\n';
  ss << " .num_rows = static_cast<uint32_t>(" << tiler.ActualSize(tiling.total_rows_expr) << ")," << '\n';
  ss << " .num_dst_cols = " << tiler.Size(tiling.dst_col_size_expr, true) << "," << '\n';
  emit_array(" .num_srcs_cols = {", tiling.src_col_size_exprs, plain_size);
  if (tiling.any_padded) {
    emit_array(" .src_row_strides = {", tiling.src_row_strides, plain_size);
    emit_array(" .src_second_last_dim_strides = {", tiling.src_non_zero_strides, padded_only);
    emit_array(" .gather_mask_dim_sizes = {", tiling.last_dim_size_exprs, padded_only);
  }
  ss << "};" << '\n';
}
/**
 * Emits the `concat_tiling` constant of type concat::ConcatByGatherTiling used by ConcatExtendByGather.
 * Only the first source column size is emitted (num_src_cols).
 *
 * @param tiling tiling to serialize (src_col_size_exprs must be non-empty)
 * @param tiler  renders symbolic sizes as code
 * @param ss     generated code is appended here
 */
void ConcatRegApiCall::DefineConcatTilingGather(const ConcatTiling &tiling, const Tiler &tiler, std::stringstream &ss) {
  const char *const gather_tiling_type = "concat::ConcatByGatherTiling";
  ss << "const " << gather_tiling_type << " concat_tiling {" << '\n'
     << " .num_rows = static_cast<uint32_t>(" << tiler.ActualSize(tiling.total_rows_expr) << ")," << '\n'
     << " .num_dst_cols = " << tiler.Size(tiling.dst_col_size_expr, true) << "," << '\n'
     << " .num_src_cols = " << tiler.Size(tiling.src_col_size_exprs[0], true) << "," << '\n'
     << "};" << '\n';
}
void ConcatRegApiCall::GenSrcAddrs(const vector<std::reference_wrapper<const Tensor>> &inputs,
const string &dtype_name,
std::stringstream &ss) {
ss << dtype_name << " *concat_src_addrs[] { ";
for (auto &input : inputs) {
const auto &x = input.get();
ss << "(" << dtype_name << " *)" << x << ".GetPhyAddr(), ";
}
ss << "};" << std::endl;
}
/**
 * Emits the one-axis concat: a constexpr concat::ConcatTilingOneAxis, the source address array and a
 * concat::ConcatOneAxis call.
 *
 * Each input's column count is the constant size of its second vectorized axis. Destination offsets are
 * the running prefix sums of those counts. When reached from Generate, CanConcatOneAxis has already
 * accepted these inputs.
 *
 * @param inputs source tensors in concat order
 * @param y      destination tensor
 * @param ss     generated code is appended here
 * @return ge::SUCCESS, or a failure status if the dtype name or a column size cannot be resolved
 */
Status ConcatRegApiCall::GenerateForOneAxis(const vector<std::reference_wrapper<const Tensor>> &inputs, const Tensor &y,
                                            std::stringstream &ss) {
  std::string dtype_name;
  GE_CHK_STATUS_RET(Tensor::DtypeName(y.dtype, dtype_name), "Codegen get data type:%d failed",
                    static_cast<int32_t>(y.dtype));
  ss << "constexpr concat::ConcatTilingOneAxis<" << inputs.size() << "> concat_tiling {" << std::endl;
  ss << " .src_col_sizes = { ";
  std::vector<uint32_t> dst_col_offsets;
  dst_col_offsets.reserve(inputs.size());
  uint32_t dst_col_offset = 0U;
  for (const auto &input : inputs) {
    const auto &x = input.get();
    // Guard the indexing below instead of relying on the caller to have checked the layout.
    GE_ASSERT_TRUE(x.vectorized_axis_pos.size() > 1UL, "Concat one-axis input needs 2 vectorized axes.");
    const auto pos = x.vectorized_axis_pos[1];
    GE_ASSERT_TRUE(static_cast<size_t>(pos) < x.axis_size.size(), "Concat one-axis input axis pos out of range.");
    const auto &axis_size = x.axis_size[pos];
    GE_ASSERT_TRUE(axis_size.IsConstExpr());
    uint32_t src_col_size = 0U;
    GE_ASSERT_TRUE(axis_size.GetConstValue(src_col_size));
    ss << src_col_size << ", ";
    dst_col_offsets.push_back(dst_col_offset);
    dst_col_offset += src_col_size;
  }
  ss << "}," << std::endl;
  ss << " .dst_col_offsets = { ";
  for (const auto offset : dst_col_offsets) {
    ss << offset << ", ";
  }
  ss << "}," << std::endl;
  ss << "};" << std::endl;
  GenSrcAddrs(inputs, dtype_name, ss);
  ss << "concat::ConcatOneAxis<" << dtype_name << ", " << inputs.size() << ">("
     << "(" << dtype_name << " *)" << y << ".GetPhyAddr()"
     << ", " << "concat_src_addrs, concat_tiling);" << std::endl;
  return ge::SUCCESS;
}
// File-local static registration object for ConcatRegApiCall under the key "ConcatRegApiCall"
// (presumably consumed by the api-call factory from api_call_factory.h; confirm against ApiCallRegister).
[[maybe_unused]] static ApiCallRegister<ConcatRegApiCall> register_concat_api_call("ConcatRegApiCall");
}