* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "concat_api_call.h"
#include "common_utils.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "ascir_ops.h"
#include "common/checker.h"
#include "../utils/api_call_factory.h"
namespace codegen {
using namespace ascgen_utils;
namespace {
constexpr uint32_t kDataBlockSize = 32;
constexpr uint32_t kAddrListSize = 16;
constexpr size_t kOffsetSecondLastStride = 2;
Status FindNonZeroStride(const std::vector<ascir::SizeExpr> &vectorized_strides,
int32_t index,
af::Expression &stride) {
for (int32_t i = index; i >= 0; --i) {
stride = vectorized_strides[i];
if (af::SymbolicUtils::StaticCheckEq(stride, af::ops::Zero) != af::TriBool::kTrue) {
break;
}
}
GE_ASSERT_TRUE(stride != af::ops::Zero,
"Failed to find non-zero v_stride before index = %d, v_strides = %s",
index, af::ToString(vectorized_strides).c_str());
return ge::SUCCESS;
}
}
Status ConcatApiCall::ParseAttr(const ascir::NodeView &node) {
GE_ASSERT_SUCCESS(ApiCall::ParseAttr(node));
(void) af::AttrUtils::GetBool(node->GetOpDesc(), "_concat_small_tail", use_concat_small_tail_api_);
node_ = node;
return ge::SUCCESS;
}
Status ConcatApiCall::Generate(const TPipe &tpipe, const std::vector<ascir::AxisId> ¤t_axis,
const vector<std::reference_wrapper<const Tensor>> &inputs,
const vector<std::reference_wrapper<const Tensor>> &outputs,
string &result) const {
(void)current_axis;
GE_CHK_BOOL_RET_STATUS((!inputs.empty()) && (!outputs.empty()), ge::FAILED,
"Codegen input or output tensor is empty");
const auto &x0 = inputs[0].get();
const auto &y = outputs[0].get();
size_t concat_dim;
GE_ASSERT_SUCCESS(ParseConcatDim(x0, y, concat_dim), "Failed to parse concat dim");
std::string dtype_name;
GE_CHK_STATUS_RET(Tensor::DtypeName(y.dtype, dtype_name), "Codegen get data type:%d failed",
static_cast<int32_t>(y.dtype));
int64_t life_time_axis_id = -1L;
int64_t id = -1L;
auto it = this->tmp_buf_id.find(life_time_axis_id);
if (it != this->tmp_buf_id.end()) {
id = it->second;
}
ConcatTiling concat_tiling;
GE_ASSERT_SUCCESS(InitializeTiling(concat_dim, inputs, y, concat_tiling));
std::stringstream ss;
if (use_concat_small_tail_api_) {
GE_ASSERT_TRUE(id != -1L, "ConcatApiCall cannot find tmp buffer id to use.");
if (concat_tiling.dst_col_size_expr.IsConstExpr()) {
GE_ASSERT_SUCCESS(CalcTiling(concat_dim, inputs, concat_tiling));
DefineInputList(concat_tiling, dtype_name, inputs, ss);
GE_ASSERT_SUCCESS(DefineConcatContext(concat_tiling, dtype_name, tpipe.tiler, ss));
DefineConcatTiling(concat_tiling, ss);
ss << "ConcatExtendV2(concat_context, tiling, " << y << ", " << tpipe.tmp_buf << "_" << std::to_string(id)
<< ");" << std::endl;
} else {
DefineInputList(concat_tiling, dtype_name, inputs, ss);
GE_ASSERT_SUCCESS(DefineConcatContext(concat_tiling, dtype_name, tpipe.tiler, ss));
DefineConcatShape(concat_tiling, tpipe.tiler, ss);
ss << "ConcatExtendV2Dyn(concat_context, concat_shape, " << y << ", " << tpipe.tmp_buf << "_" << std::to_string(id)
<< ");" << std::endl;
}
} else {
if (IsAllAligned(concat_tiling, concat_tiling.src_col_actual_size_exprs)) {
GE_ASSERT_SUCCESS(GenerateForAllAligned(inputs, y, concat_tiling, tpipe.tiler, ss));
} else {
GE_ASSERT_TRUE(id != -1L, "ConcatApiCall cannot find tmp buffer id to use.");
ss << "uint32_t concat_dim = " << concat_dim << ";" << std::endl;
GenConcatParams(inputs, y, tpipe.tiler, dtype_name, ss);
ss << "ConcatExtend<" << dtype_name << ", " << y.vectorized_axis.size() << ", " << std::to_string(inputs.size())
<< ">(dst_params, srcs_params, concat_dim, " << tpipe.tmp_buf << "_" << std::to_string(id)
<< ");" << std::endl;
}
}
result = ss.str();
return ge::SUCCESS;
}
ge::Status ConcatApiCall::GenerateForAllAligned(const vector<std::reference_wrapper<const Tensor>> &inputs,
const Tensor &y,
const ConcatApiCall::ConcatTiling &tiling,
const Tiler &tiler,
std::stringstream &ss) {
std::string dtype_name;
GE_CHK_STATUS_RET(Tensor::DtypeName(y.dtype, dtype_name), "Codegen get data type:%d failed",
static_cast<int32_t>(y.dtype));
GenConcatTilingForAllAligned(tiling, tiler, ss);
GenSrcTensors(inputs, dtype_name, ss);
ss << "ConcatAllAligned<" << dtype_name << ", " << inputs.size() << ">("
<< tiler.ActualSize(tiling.total_rows_expr) << ", " << "concat_tiling"
<< ", " << y << ", " << "concat_src_tensors"
<< ");" << std::endl;
return ge::SUCCESS;
}
void ConcatApiCall::GenConcatParams(const std::vector<std::reference_wrapper<const Tensor>> &inputs, const Tensor &y,
const Tiler &tiler, const string &dtype_name, std::stringstream &ss) {
ss << "const ConcatParams<" + dtype_name + ", " << y.vectorized_axis.size() << ">"
<< " dst_params = {" << std::endl
<< ".shape = {";
for (size_t i = 0; i < y.vectorized_axis.size(); i++) {
auto pos = y.vectorized_axis_pos[i];
ss << tiler.Size(y.axis_size[pos], true);
if (i != y.vectorized_axis.size() - 1) {
ss << ", ";
}
}
ss << "}," << std::endl << ".stride = {";
for (size_t i = 0; i < y.vectorized_axis.size(); i++) {
ss << tiler.Size(y.vectorized_strides[i], true);
if (i != y.vectorized_axis.size() - 1) {
ss << ", ";
}
}
ss << "}," << std::endl << ".tensor = &" << y << "," << std::endl << "};" << std::endl;
ss << "const ConcatParams<" + dtype_name + ", " << y.vectorized_axis.size() << "> srcs_params["
<< std::to_string(inputs.size()) << "] = {" << std::endl;
for (auto &input : inputs) {
const auto &x = input.get();
ss << "{" << std::endl << ".shape = {";
for (size_t i = 0U; i < x.vectorized_axis.size(); i++) {
auto pos = x.vectorized_axis_pos[i];
ss << tiler.Size(x.axis_size[pos], true);
if (i != x.vectorized_axis.size()) {
ss << ", ";
}
}
ss << "}," << std::endl << ".stride = {";
for (size_t i = 0U; i < x.vectorized_axis.size(); i++) {
ss << tiler.Size(x.vectorized_strides[i], true);
if (i != x.vectorized_axis.size() - 1) {
ss << ", ";
}
}
ss << "}," << std::endl << ".tensor = &" << x << "," << std::endl << "}," << std::endl;
}
ss << "};" << std::endl;
}
Status ConcatApiCall::ParseConcatDim(const Tensor &x0, const Tensor &y, size_t &concat_dim) {
GELOGD("x0_t_name:%s, axis_id:%s, size:%s, strides:%s, v_axis_id:%s, v_axis_pos:%s, v_strides:%s", x0.name.c_str(),
VectorToStr(x0.axis).c_str(), VectorToStr(x0.axis_size).c_str(), VectorToStr(x0.axis_strides).c_str(),
VectorToStr(x0.vectorized_axis).c_str(), VectorToStr(x0.vectorized_axis_pos).c_str(),
VectorToStr(x0.vectorized_strides).c_str());
GELOGD("y_t_name:%s, axis_id:%s, size:%s, strides:%s, v_axis_id:%s, v_axis_pos:%s, v_strides:%s", y.name.c_str(),
VectorToStr(y.axis).c_str(), VectorToStr(y.axis_size).c_str(), VectorToStr(y.axis_strides).c_str(),
VectorToStr(y.vectorized_axis).c_str(), VectorToStr(y.vectorized_axis_pos).c_str(),
VectorToStr(y.vectorized_strides).c_str());
GE_CHK_BOOL_RET_STATUS(x0.vectorized_axis.size() == y.vectorized_axis.size(), ge::FAILED,
"Codegen concat input output vectorized_axis not equal");
bool find_concat_dim = false;
for (size_t i = 0; i < y.vectorized_axis.size(); i++) {
auto pos = y.vectorized_axis_pos[i];
if (af::SymbolicUtils::StaticCheckEq(x0.axis_size[pos], y.axis_size[pos]) != af::TriBool::kTrue) {
concat_dim = i;
find_concat_dim = true;
GELOGI("find concat_dim:%zu, ", concat_dim);
break;
}
}
GE_ASSERT_TRUE(find_concat_dim, "not find concat dim in vectorized_axis");
return ge::SUCCESS;
}
Status ConcatApiCall::InitializeTiling(size_t concat_dim, const vector<std::reference_wrapper<const Tensor>> &inputs,
const Tensor &y, ConcatApiCall::ConcatTiling &tiling) {
auto data_type_size = ge::GetSizeByDataType(y.dtype);
GE_ASSERT_TRUE(data_type_size > 0);
tiling.data_type_size = static_cast<uint32_t>(data_type_size);
tiling.total_rows_expr = af::ops::One;
tiling.dst_col_size_expr = af::ops::One;
for (size_t i = 0; i < y.vectorized_axis.size(); ++i) {
auto pos = y.vectorized_axis_pos[i];
auto axis_size_expr = y.axis_size[pos];
if (i < concat_dim) {
tiling.total_rows_expr = tiling.total_rows_expr * axis_size_expr;
} else {
tiling.dst_col_size_expr = tiling.dst_col_size_expr * axis_size_expr;
}
}
const auto &concat_dim_stride = y.vectorized_strides[concat_dim];
GE_ASSERT_SUCCESS(FindNonZeroStride(y.vectorized_strides,
static_cast<int32_t>(concat_dim) - 1,
tiling.dst_row_stride));
tiling.src_col_size_exprs.resize(inputs.size());
tiling.src_col_actual_size_exprs.resize(inputs.size());
tiling.src_non_zero_strides.resize(inputs.size(), af::ops::Zero);
tiling.src_row_strides.resize(inputs.size(), af::ops::Zero);
tiling.last_dim_size_exprs.resize(inputs.size());
for (size_t input_index = 0; input_index < inputs.size(); ++input_index) {
auto &x = inputs[input_index].get();
tiling.src_col_size_exprs[input_index] = af::ops::One;
tiling.last_dim_size_exprs[input_index] = x.axis_size[x.vectorized_axis_pos[x.vectorized_axis.size() - 1]];;
for (size_t i = concat_dim; i < x.vectorized_axis.size(); ++i) {
auto pos = x.vectorized_axis_pos[i];
auto axis_size_expr = x.axis_size[pos];
tiling.src_col_size_exprs[input_index] = tiling.src_col_size_exprs[input_index] * axis_size_expr;
}
tiling.src_col_actual_size_exprs[input_index] = x.axis_size[x.vectorized_axis_pos[concat_dim]] * concat_dim_stride;
GE_ASSERT_SUCCESS(FindNonZeroStride(x.vectorized_strides,
static_cast<int32_t>(x.vectorized_strides.size()) - kOffsetSecondLastStride,
tiling.src_non_zero_strides[input_index]));
GE_ASSERT_SUCCESS(FindNonZeroStride(x.vectorized_strides,
static_cast<int32_t>(concat_dim) - 1,
tiling.src_row_strides[input_index]));
auto is_padded = af::SymbolicUtils::StaticCheckEq(tiling.src_non_zero_strides[input_index],
tiling.last_dim_size_exprs[input_index]) != af::TriBool::kTrue;
GE_CHK_BOOL_ONLY_LOG((!is_padded),
"Input[%zu] is_padded = %d, axes = %s, strides = %s",
input_index, static_cast<int32_t>(is_padded),
VectorToStr(x.vectorized_axis).c_str(), VectorToStr(x.vectorized_strides).c_str());
tiling.is_padded.emplace_back(is_padded);
tiling.any_padded = (tiling.any_padded || is_padded);
}
return ge::SUCCESS;
}
Status ConcatApiCall::CalcTiling([[maybe_unused]] size_t concat_dim, const vector<std::reference_wrapper<const Tensor>> &inputs,
ConcatApiCall::ConcatTiling &tiling) const {
const uint32_t kScaleToB16 = tiling.data_type_size / sizeof(uint16_t);
const uint32_t kEltNumPerBlock = kAddrListSize * kDataBlockSize / tiling.data_type_size;
GE_ASSERT_TRUE(!node_->attr.tmp_buffers.empty());
GE_ASSERT_TRUE(node_->attr.tmp_buffers.front().buf_desc.size.GetConstValue(tiling.tmp_buf_size));
GE_ASSERT_TRUE(tiling.dst_row_stride.GetConstValue(tiling.dst_col_size));
tiling.gcd = kAddrListSize;
for (size_t input_index = 0; input_index < inputs.size(); ++input_index) {
int64_t dim_size = -1;
GE_ASSERT_TRUE(tiling.src_col_size_exprs[input_index].GetConstValue(dim_size));
GE_ASSERT_TRUE(dim_size < std::numeric_limits<uint32_t>::max());
tiling.gcd = ascgen_utils::Gcd(tiling.gcd, static_cast<uint32_t>(dim_size));
tiling.src_col_sizes.emplace_back(dim_size);
}
tiling.dst_row_num_unit = tiling.dst_col_size * kScaleToB16;
tiling.max_repeat_times = (tiling.tmp_buf_size >> 10U) / (tiling.dst_col_size / tiling.gcd);
GE_ASSERT_TRUE(tiling.max_repeat_times > 0);
tiling.max_element_num = tiling.max_repeat_times * (tiling.dst_col_size / tiling.gcd) * kEltNumPerBlock;
tiling.max_orig_row_num = tiling.max_element_num / tiling.dst_col_size;
tiling.first_copy_repeat_times = tiling.max_repeat_times * kAddrListSize / kScaleToB16 / tiling.gcd;
tiling.last_trans_repeat_times = tiling.max_repeat_times * (tiling.dst_col_size / tiling.gcd);
tiling.per_repeat_size = (tiling.dst_col_size / tiling.gcd) * kEltNumPerBlock;
CalcTilingForInputs(inputs, kEltNumPerBlock, tiling);
GELOGI("ConcatTiling: gcd=%u, tmp_buf_size=%u, max_repeat_times=%u, max_element_num=%u, max_orig_row_num=%u",
tiling.gcd, tiling.tmp_buf_size, tiling.max_repeat_times, tiling.max_element_num, tiling.max_orig_row_num);
return ge::SUCCESS;
}
Status ConcatApiCall::CalcTilingForInputs(const vector<std::reference_wrapper<const Tensor>> &inputs,
size_t block_size, ConcatApiCall::ConcatTiling &concat_tiling) {
uint32_t buffer_offset = 0;
for (size_t input_index = 0; input_index < inputs.size(); ++input_index) {
int64_t src_row_stride = -1;
GE_ASSERT_TRUE(concat_tiling.src_row_strides[input_index].GetConstValue(src_row_stride));
concat_tiling.src_loop_strides.emplace_back(concat_tiling.max_orig_row_num * src_row_stride);
concat_tiling.src_buffer_offsets.emplace_back(buffer_offset);
buffer_offset +=
(concat_tiling.max_repeat_times * (concat_tiling.src_col_sizes[input_index] / concat_tiling.gcd) * block_size);
int64_t second_last_stride = -1;
GE_ASSERT_TRUE(concat_tiling.src_non_zero_strides[input_index].GetConstValue(second_last_stride));
int64_t last_dim_size = -1;
GE_ASSERT_TRUE(concat_tiling.last_dim_size_exprs[input_index].GetConstValue(last_dim_size));
GE_ASSERT_TRUE(second_last_stride > 0, "Failed to get second last stride");
concat_tiling.second_last_dim_strides.emplace_back(second_last_stride);
concat_tiling.gather_mask_dim_sizes.emplace_back(last_dim_size);
}
return ge::SUCCESS;
}
void ConcatApiCall::DefineConcatTiling(const ConcatApiCall::ConcatTiling &tiling, std::stringstream &ss) {
ss << "constexpr ConcatTiling<" << tiling.src_col_size_exprs.size() << "> tiling {" << std::endl;
ss << " .gcd = " << tiling.gcd << ", " << std::endl;
ss << " .tmp_buf_size = " << tiling.tmp_buf_size << ", " << std::endl;
ss << " .dst_dim_size = " << tiling.dst_col_size << ", " << std::endl;
ss << " .dst_row_num_unit = " << tiling.dst_row_num_unit << ", " << std::endl;
ss << " .max_repeat_times = " << tiling.max_repeat_times << ", " << std::endl;
ss << " .max_element_num = " << tiling.max_element_num << ", " << std::endl;
ss << " .max_orig_row_num = " << tiling.max_orig_row_num << ", " << std::endl;
ss << " .per_repeat_size = " << tiling.per_repeat_size << ", " << std::endl;
ss << " .first_copy_repeat_times = " << tiling.first_copy_repeat_times << ", " << std::endl;
ss << " .last_trans_repeat_times = " << tiling.last_trans_repeat_times << ", " << std::endl;
ss << " .src_dim_sizes = {";
for (const auto &src_dim_size : tiling.src_col_size_exprs) {
ss << src_dim_size << ", ";
}
ss << "}," << std::endl;
ss << " .src_strides = {";
for (const auto src_stride : tiling.src_loop_strides) {
ss << src_stride << ", ";
}
ss << "}," << std::endl;
ss << " .src_buffer_offsets = {";
for (const auto src_buffer_offset : tiling.src_buffer_offsets) {
ss << src_buffer_offset << ", ";
}
ss << "}," << std::endl;
ss << " .gather_mask_repeat_strides = {";
for (size_t i = 0; i < tiling.second_last_dim_strides.size(); ++i) {
auto repeat_stride =
tiling.is_padded[i] ? tiling.second_last_dim_strides[i] * tiling.data_type_size / kDataBlockSize : 0U;
ss << repeat_stride << ", ";
}
ss << "}," << std::endl;
ss << " .gather_mask_dim_sizes = {";
for (const auto scale : tiling.gather_mask_dim_sizes) {
ss << scale << ", ";
}
ss << "}" << std::endl;
ss << "};" << std::endl;
}
void ConcatApiCall::DefineConcatShape(const ConcatApiCall::ConcatTiling &tiling,
const Tiler &tiler,
std::stringstream &ss) {
ss << "const ConcatShape<" << tiling.src_col_size_exprs.size() << "> concat_shape {" << std::endl;
ss << " .dst_cols = " << tiler.Size(tiling.dst_row_stride) << ", " << std::endl;
ss << " .src_cols = {";
for (const auto &src_dim_size : tiling.src_col_size_exprs) {
ss << tiler.Size(src_dim_size) << ", ";
}
ss << "}," << std::endl;
if (tiling.any_padded) {
ss << " .src_row_strides = {";
for (const auto &src_row_stride : tiling.src_row_strides) {
ss << tiler.Size(src_row_stride) << ", ";
}
ss << "}," << std::endl;
ss << " .src_second_last_dim_strides = {";
for (const auto &second_last_dim_stride : tiling.src_non_zero_strides) {
ss << tiler.Size(second_last_dim_stride) << ", ";
}
ss << "}," << std::endl;
ss << " .gather_mask_dim_sizes = {";
for (const auto &last_dim_size_expr : tiling.last_dim_size_exprs) {
ss << tiler.Size(last_dim_size_expr) << ", ";
}
ss << "}," << std::endl;
}
ss << "};" << std::endl;
}
ge::Status ConcatApiCall::DefineConcatContext(const ConcatTiling &tiling,
const std::string &dtype_name,
const Tiler &tiler,
std::stringstream &ss) {
std::string sub_type = "DiffDim";
bool concat_same_dim = false;
std::set<int64_t> concat_axis_sizes;
if (tiling.dst_col_size_expr.IsConstExpr()) {
int64_t dst_col_size;
GE_ASSERT_TRUE(tiling.dst_col_size_expr.GetConstValue(dst_col_size));
int64_t dst_row_stride;
GE_ASSERT_TRUE(tiling.dst_row_stride.GetConstValue(dst_row_stride));
concat_axis_sizes.insert(tiling.src_col_sizes.cbegin(), tiling.src_col_sizes.cend());
concat_same_dim = (concat_axis_sizes.size() == 1) &&
(*concat_axis_sizes.cbegin() == 1) &&
(dst_col_size == dst_row_stride);
if (concat_same_dim) {
sub_type = "SameDim";
}
}
const auto &padding_type = (tiling.any_padded ? "Padded" : "");
const auto &concat_type = std::string("ConcatContext") + sub_type + padding_type;
GELOGI("Concat type = %s", concat_type.c_str());
ss << concat_type << "<" << dtype_name << ", " << tiling.src_col_size_exprs.size();
if (concat_same_dim) {
ss << ", " << *concat_axis_sizes.cbegin();
}
ss << "> concat_context;" << std::endl;
ss << "concat_context.total_row_num = " << tiler.ActualSize(tiling.total_rows_expr) << ";" << std::endl;
ss << "concat_context.input_list = &input_list;" << std::endl;
return ge::SUCCESS;
}
void ConcatApiCall::DefineInputList(const ConcatTiling &tiling, const std::string &dtype_name,
const std::vector<std::reference_wrapper<const Tensor>> &inputs,
std::stringstream &ss) {
ss << "ConcatInputList<" << dtype_name << ", " << inputs.size() << "> input_list {" << std::endl;
ss << " .src_tensor_base_addrs = {";
for (const auto &input : inputs) {
ss << "(" << dtype_name << " *)" << input << ".GetPhyAddr(), ";
}
ss << "}," << std::endl;
if (tiling.any_padded) {
ss << " .src_tensors = {";
for (size_t i = 0U; i < inputs.size(); ++i) {
std::string addr = tiling.is_padded[i] ? std::string("&") + inputs[i].get().Str() : "nullptr";
ss << addr << ", ";
}
ss << "}," << std::endl;
}
ss << "};" << std::endl;
}
bool ConcatApiCall::IsAllAligned(ConcatApiCall::ConcatTiling &tiling,
const std::vector<af::Expression> &col_size_exprs) {
tiling.dst_offsets.resize(col_size_exprs.size());
auto offset = af::ops::Zero;
for (size_t index = 0U; index < col_size_exprs.size(); ++index) {
auto size = af::Symbol(tiling.data_type_size) * col_size_exprs[index];
if (af::SymbolicUtils::StaticCheckEq(af::sym::Mod(size, af::Symbol(kDataBlockSize)), af::ops::Zero) != af::TriBool::kTrue) {
GELOGI("input[%zu] size = %s, is not aligned", index,
af::SymbolicUtils::ToString(col_size_exprs[index]).c_str());
return false;
}
if (!col_size_exprs[index].IsConstExpr()) {
tiling.all_static = false;
}
tiling.dst_offsets[index] = offset;
offset = offset + col_size_exprs[index];
}
return true;
}
void ConcatApiCall::GenConcatTilingForAllAligned(const ConcatApiCall::ConcatTiling &tiling, const Tiler &tiler,
std::stringstream &ss) {
const auto qualifier = tiling.all_static ? "constexpr " : "const ";
ss << qualifier;
ss << "ConcatTilingAllAligned<" << tiling.src_col_size_exprs.size() << "> concat_tiling {" << std::endl;
ss << " .dst_col_size = " << tiler.Size(tiling.dst_row_stride, true) << "," << std::endl;
ss << " .src_col_sizes = { ";
for (const auto &src_col_size : tiling.src_col_actual_size_exprs) {
ss << tiler.Size(src_col_size, true) << ", ";
}
ss << "}," << std::endl;
ss << " .dst_offsets = { ";
for (const auto &src_col_size : tiling.dst_offsets) {
ss << tiler.Size(src_col_size, true) << ", ";
}
ss << "}," << std::endl;
ss << "};" << std::endl;
}
void ConcatApiCall::GenSrcTensors(const std::vector<std::reference_wrapper<const Tensor>> &inputs,
const std::string &dtype_name, std::stringstream &ss) {
ss << "LocalTensor<" << dtype_name << "> concat_src_tensors[] { ";
for (auto &input : inputs) {
const auto &x = input.get();
ss << x << ", ";
}
ss << "};" << std::endl;
}
[[maybe_unused]] static ApiCallRegister<ConcatApiCall> register_concat_api_call("ConcatApiCall");
}