/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "data_task_builder.h"
namespace ffts {
namespace {
// Burst length used by the operation-taking DataTaskBuilder constructor (192 * 1024 * 1024 == 201326592).
constexpr int64_t kMaxBurstLen = 192LL * 1024 * 1024;
// Lookup table keyed by cache operation. Each tuple holds, in order:
//   <0> name of the bitmap attribute read from the op desc,
//   <1> context type written into every generated context,
//   <2> true when the operation works on input tensor descs, false for output tensor descs,
//   <3> operation name used in logs, in context names and as prefix of the "<name>_idx" attribute.
// NOTE(review): PREFETCH is mapped to RT_CTX_TYPE_FLUSH_DATA -- confirm this is the intended runtime type.
const std::map<CACHE_OPERATION, std::tuple<string, ContextType, bool, string>> kDataOptInfo = {
    {CACHE_OPERATION::PREFETCH,
     std::make_tuple(kPrefetchEnableBm, RT_CTX_TYPE_FLUSH_DATA, true, "prefetch")},
    {CACHE_OPERATION::INVALIDATE,
     std::make_tuple(kInvalidateBm, RT_CTX_TYPE_INVALIDATE_DATA, false, "invalidate")},
    {CACHE_OPERATION::WRITE_BACK,
     std::make_tuple(kWriteBackBm, RT_CTX_TYPE_WRITEBACK_DATA, false, "write back")},
};
/*
 * Stores ctx_id_vec as the kTensorCtxId list attribute on one tensor desc of `node`.
 *
 * anchor_index selects the tensor desc; is_input chooses between the input and the output desc.
 * Returns SUCCESS once the attribute call has been issued (its own result is ignored); a missing
 * tensor desc is rejected by FFTS_CHECK_NOTNULL.
 */
Status SetDataContextId(size_t anchor_index, const ge::NodePtr &node, bool is_input,
                        std::vector<uint32_t> &ctx_id_vec) {
  FFTS_LOGD("Node [%s] set context ID attribute with size [%zu].", node->GetName().c_str(), ctx_id_vec.size());
  ge::GeTensorDescPtr desc_ptr = is_input ? node->GetOpDesc()->MutableInputDesc(anchor_index)
                                          : node->GetOpDesc()->MutableOutputDesc(anchor_index);
  FFTS_CHECK_NOTNULL(desc_ptr);
  (void)ge::AttrUtils::SetListInt(desc_ptr, kTensorCtxId, ctx_id_vec);
  return SUCCESS;
}
}
/* Creates a builder with no usable cache operation selected (CACHE_OPERATION_BOTTOM).
 * NOTE(review): burst_len_ is not initialised here; presumably the class declaration provides a default or
 * callers invoke SetBurstLen() first -- confirm against the header. */
DataTaskBuilder::DataTaskBuilder()
    : operation_(CACHE_OPERATION::CACHE_OPERATION_BOTTOM) {}

/* Creates a builder bound to `operation`; the burst length starts at kMaxBurstLen. */
DataTaskBuilder::DataTaskBuilder(CACHE_OPERATION operation)
    : operation_(operation),
      burst_len_(kMaxBurstLen) {}

DataTaskBuilder::~DataTaskBuilder() = default;
/*
 * Generates the data contexts of `node` for the cache operation held in operation_ (manual threading).
 *
 * kDataOptInfo supplies the bitmap attribute name, the context type, whether input or output tensors are
 * handled and the operation name. For every index returned by GetIndices() whose bit is set in the bitmap,
 * one context is appended to ffts_plus_task_def, its id is stored on the tensor desc (kTensorCtxId) and the
 * index is recorded in the node attribute "<operation name>_idx".
 *
 * Returns SUCCESS when there is nothing to generate (bitmap attribute missing or zero, or the kernel lib
 * name equals kRtsFftsPlusOpStoreName), FAILED when operation_ has no entry in kDataOptInfo, and the status
 * of FillManualDataCtx when filling a context fails.
 */
Status DataTaskBuilder::GenManualDataCtxDef(const ge::NodePtr &node, domi::FftsPlusTaskDef *ffts_plus_task_def) {
  FFTS_LOGD("DataTaskBuilder::GenManualContextDef begin, node name:%s, node type:%s.", node->GetName().c_str(),
            node->GetType().c_str());
  // map::at() would throw for an operation without a table entry, e.g. the default CACHE_OPERATION_BOTTOM.
  const auto opt_iter = kDataOptInfo.find(operation_);
  if (opt_iter == kDataOptInfo.end()) {
    REPORT_FFTS_ERROR("[GenTask][DataTskBuilder][GenCtxDef]Unsupported cache operation %d for node %s.",
                      static_cast<int32_t>(operation_), node->GetName().c_str());
    return FAILED;
  }
  const auto &operation = opt_iter->second;
  const string &bm_name = std::get<0>(operation);
  rtFftsPlusContextType_t context_type = std::get<1>(operation);
  bool is_input = std::get<2>(operation);
  const string &operation_name = std::get<3>(operation);

  auto op_desc = node->GetOpDesc();
  int64_t bm = 0;
  if (!ge::AttrUtils::GetInt(op_desc, bm_name, bm) || bm == 0) {
    return SUCCESS;
  }
  const auto &lib_name = op_desc->GetOpKernelLibName();
  FFTS_LOGD("Node %s needs %s context, lib_name: %s.", node->GetName().c_str(), operation_name.c_str(),
            lib_name.c_str());
  if (lib_name == kRtsFftsPlusOpStoreName) {
    return SUCCESS;
  }

  auto indices = GetIndices(node);
  size_t curr_num = 0;
  std::vector<int32_t> cmo_idx;
  for (auto index : indices) {
    // Skip tensors whose bit is not set in the bitmap attribute.
    uint64_t curr_bm = 0;
    SetBitOne(static_cast<uint32_t>(index), curr_bm);
    if ((static_cast<uint64_t>(bm) & curr_bm) == 0) {
      FFTS_LOGD("bm is %lld and current BM is %llu.", bm, curr_bm);
      continue;
    }
    FFTS_LOGD("Node %s's tensor %d needs %s.", node->GetName().c_str(), index, operation_name.c_str());

    // A tensor that cannot be sliced is skipped; it does not fail the whole node.
    std::vector<DataContextParam> data_ctx_params;
    Status ret = MemorySlice::GenerateManualDataCtxParam(node, index, is_input, burst_len_, data_ctx_params);
    if (ret != SUCCESS) {
      continue;
    }
    if (ExceedMaxCtxNum(curr_num, data_ctx_params.size())) {
      FFTS_LOGI("Exceeded the upper limit of prefetch context. Current total is %zu, pending total is %zu.",
                curr_num, data_ctx_params.size());
      continue;
    }
    curr_num += data_ctx_params.size();
    FFTS_LOGD("Size of context is %zu, curr_num is %zu.", data_ctx_params.size(), curr_num);

    size_t anchor_index = static_cast<size_t>(index);
    std::vector<uint32_t> ctx_id_vec;
    for (auto &param : data_ctx_params) {
      domi::FftsPlusCtxDef *ffts_plus_ctx_def = ffts_plus_task_def->add_ffts_plus_ctx();
      FFTS_CHECK_NOTNULL(ffts_plus_ctx_def);
      ffts_plus_ctx_def->set_context_type(context_type);
      // The context id is the position of the context that was just appended.
      ffts_plus_ctx_def->set_context_id(ffts_plus_task_def->ffts_plus_ctx_size() - 1);
      ffts_plus_ctx_def->set_uniq_ctx_name(op_desc->GetName() + "_" + operation_name + "_" + to_string(index));
      domi::FftsPlusDataCtxDef *data_ctx_def = ffts_plus_ctx_def->mutable_data_ctx();
      FFTS_CHECK_NOTNULL(data_ctx_def);
      uint32_t ctx_id = static_cast<uint32_t>(ffts_plus_task_def->ffts_plus_ctx_size() - 1);
      ctx_id_vec.emplace_back(ctx_id);
      FFTS_LOGD("Filling one %s context[%u] for tensor %zu in node %s.", operation_name.c_str(), ctx_id,
                anchor_index, node->GetName().c_str());
      Status status = FillManualDataCtx(anchor_index, node, param, ffts_plus_task_def, data_ctx_def);
      if (status != SUCCESS) {
        REPORT_FFTS_ERROR("[GenTask][InvldTsk][GenCtxDef]Fill context %s %zu failed. Op[%s], optype[%s]",
                          operation_name.c_str(), anchor_index, op_desc->GetName().c_str(),
                          op_desc->GetType().c_str());
        return status;
      }
      // NOTE(review): only the first param produces a context although curr_num counts all of them -- confirm.
      break;
    }
    (void)SetDataContextId(anchor_index, node, is_input, ctx_id_vec);
    cmo_idx.emplace_back(index);
  }
  (void)ge::AttrUtils::SetListInt(node->GetOpDesc(), operation_name + "_idx", cmo_idx);
  return SUCCESS;
}
/*
 * Generates the data contexts of `node` for the cache operation held in operation_ (auto threading).
 *
 * For every index returned by GetIndices() whose bit is set in the bitmap attribute, one context per
 * parallel window (parallel_window_size of the kAttrSgtStructInfo slice info) is appended to
 * ffts_plus_task_def. The ids of those contexts are stored on the tensor desc (kTensorCtxId) and the index
 * is recorded in the node attribute "<operation name>_idx".
 *
 * Returns SUCCESS when the bitmap attribute is missing or zero, FAILED when operation_ has no entry in
 * kDataOptInfo, and the status of FillAutoDataCtx when filling a context fails.
 */
Status DataTaskBuilder::GenAutoDataCtxDef(const ge::NodePtr &node, domi::FftsPlusTaskDef *ffts_plus_task_def) {
  FFTS_LOGD("DataTaskBuilder::GenAutoContextDef begin, node name:%s, node type:%s", node->GetName().c_str(),
            node->GetType().c_str());
  // map::at() would throw for an operation without a table entry, e.g. the default CACHE_OPERATION_BOTTOM.
  const auto opt_iter = kDataOptInfo.find(operation_);
  if (opt_iter == kDataOptInfo.end()) {
    REPORT_FFTS_ERROR("[GenTask][DataTskBuilder][GenCtxDef]Unsupported cache operation %d for node %s.",
                      static_cast<int32_t>(operation_), node->GetName().c_str());
    return FAILED;
  }
  const auto &operation = opt_iter->second;
  const string &bm_name = std::get<0>(operation);
  rtFftsPlusContextType_t context_type = std::get<1>(operation);
  bool is_input = std::get<2>(operation);
  const string &operation_name = std::get<3>(operation);

  auto op_desc = node->GetOpDesc();
  int64_t bm = 0;
  if (!ge::AttrUtils::GetInt(op_desc, bm_name, bm) || bm == 0) {
    return SUCCESS;
  }
  FFTS_LOGD("Node %s needs %s context", node->GetName().c_str(), operation_name.c_str());

  auto indices = GetIndices(node);
  size_t curr_num = 0;
  std::vector<int32_t> cmo_idx;
  for (auto index : indices) {
    // Skip tensors whose bit is not set in the bitmap attribute.
    uint64_t curr_bm = 0;
    SetBitOne(static_cast<uint32_t>(index), curr_bm);
    if ((static_cast<uint64_t>(bm) & curr_bm) == 0) {
      FFTS_LOGD("bm is %lld and current BM is %llu.", bm, curr_bm);
      continue;
    }
    FFTS_LOGD("Node %s's tensor %d needs %s.", node->GetName().c_str(), index, operation_name.c_str());

    // A tensor that cannot be sliced is skipped; it does not fail the whole node.
    std::vector<DataContextParam> param_nontail_tail;
    Status ret = MemorySlice::GenerateAutoDataCtxParam(node, index, is_input, burst_len_, param_nontail_tail);
    if (ret != SUCCESS) {
      continue;
    }
    if (ExceedMaxCtxNum(curr_num, param_nontail_tail.size())) {
      FFTS_LOGI("Exceeded the upper limit of prefetch context. Current total is %zu, pending total is %zu.",
                curr_num, param_nontail_tail.size());
      continue;
    }
    curr_num += param_nontail_tail.size();
    FFTS_LOGD("curr_num is %zu, param_nontail_tail size is %zu", curr_num, param_nontail_tail.size());

    // Every parallel window gets its own copy of the non-tail/tail parameters.
    std::vector<std::vector<DataContextParam>> data_ctx_params;
    ThreadSliceMapPtr slice_info_ptr = nullptr;
    slice_info_ptr = op_desc->TryGetExtAttr(kAttrSgtStructInfo, slice_info_ptr);
    FFTS_CHECK_NOTNULL(slice_info_ptr);
    for (size_t i = 0; i < static_cast<size_t>(slice_info_ptr->parallel_window_size); i++) {
      data_ctx_params.push_back(param_nontail_tail);
    }

    size_t anchor_index = static_cast<size_t>(index);
    std::vector<uint32_t> ctx_id_vec;
    for (size_t i = 0; i < data_ctx_params.size(); i++) {
      domi::FftsPlusCtxDef *ffts_plus_ctx_def = ffts_plus_task_def->add_ffts_plus_ctx();
      FFTS_CHECK_NOTNULL(ffts_plus_ctx_def);
      // The context id is the position of the context that was just appended.
      uint32_t ctx_id = static_cast<uint32_t>(ffts_plus_task_def->ffts_plus_ctx_size() - 1);
      ctx_id_vec.emplace_back(ctx_id);
      ffts_plus_ctx_def->set_context_type(context_type);
      ffts_plus_ctx_def->set_context_id(ffts_plus_task_def->ffts_plus_ctx_size() - 1);
      ffts_plus_ctx_def->set_uniq_ctx_name(op_desc->GetName() + "_" + operation_name + "_" + to_string(i) + "_" +
                                           to_string(index));
      domi::FftsPlusDataCtxDef *data_ctx_def = ffts_plus_ctx_def->mutable_data_ctx();
      FFTS_CHECK_NOTNULL(data_ctx_def);
      FFTS_LOGD("Filling one %s context for tensor %zu of node %s, window_id: %zu.", operation_name.c_str(),
                anchor_index, node->GetName().c_str(), i);
      Status status = FillAutoDataCtx(anchor_index, node, data_ctx_params[i], ffts_plus_task_def, data_ctx_def, i);
      if (status != SUCCESS) {
        REPORT_FFTS_ERROR("[GenTask][InvldTsk][GenCtxDef]Fill context %s %zu failed. Op[%s], optype[%s]",
                          operation_name.c_str(), anchor_index, op_desc->GetName().c_str(),
                          op_desc->GetType().c_str());
        return status;
      }
    }
    (void)SetDataContextId(anchor_index, node, is_input, ctx_id_vec);
    cmo_idx.emplace_back(index);
  }
  (void)ge::AttrUtils::SetListInt(node->GetOpDesc(), operation_name + "_idx", cmo_idx);
  return SUCCESS;
}
/*
 * Generates the data contexts of `node` for the cache operation held in operation_ (dynamic mode).
 *
 * For every index returned by GetIndices() whose bit is set in the bitmap attribute, FillDynamicDataCtx is
 * called with the node's kAutoCtxIdList; at most kMaxPretchNum tensors are processed. The handled indices
 * are recorded in the node attribute "<operation name>_idx".
 *
 * Returns SUCCESS when the bitmap attribute is missing or zero, and FAILED when operation_ has no entry in
 * kDataOptInfo or FillDynamicDataCtx reports FAILED.
 */
Status DataTaskBuilder::GenDynamicDataCtxDef(const ge::NodePtr &node, domi::FftsPlusTaskDef *ffts_plus_task_def) {
  FFTS_LOGD("DataTaskBuilder::GenDynamicDataCtxDef begin, node name:%s, node type:%s", node->GetName().c_str(),
            node->GetType().c_str());
  // map::at() would throw for an operation without a table entry, e.g. the default CACHE_OPERATION_BOTTOM.
  const auto opt_iter = kDataOptInfo.find(operation_);
  if (opt_iter == kDataOptInfo.end()) {
    REPORT_FFTS_ERROR("[GenTask][DataTskBuilder][GenCtxDef]Unsupported cache operation %d for node %s.",
                      static_cast<int32_t>(operation_), node->GetName().c_str());
    return FAILED;
  }
  const auto &operation = opt_iter->second;
  const string &bm_name = std::get<0>(operation);
  rtFftsPlusContextType_t context_type = std::get<1>(operation);
  const string &operation_name = std::get<3>(operation);

  auto op_desc = node->GetOpDesc();
  int64_t bm = 0;
  if (!ge::AttrUtils::GetInt(op_desc, bm_name, bm) || bm == 0) {
    return SUCCESS;
  }
  FFTS_LOGD("Node %s needs %s context", node->GetName().c_str(), operation_name.c_str());

  auto indices = GetIndices(node);
  vector<uint32_t> context_id_list;
  (void)ge::AttrUtils::GetListInt(node->GetOpDesc(), kAutoCtxIdList, context_id_list);
  std::vector<int32_t> cmo_idx;
  size_t count = 0;
  for (auto index : indices) {
    // Skip tensors whose bit is not set in the bitmap attribute.
    uint64_t curr_bm = 0;
    SetBitOne(static_cast<uint32_t>(index), curr_bm);
    if ((static_cast<uint64_t>(bm) & curr_bm) == 0) {
      FFTS_LOGD("bm is %lld and current BM is %llu.", bm, curr_bm);
      continue;
    }
    FFTS_LOGD("Node %s's tensor %d needs %s.", node->GetName().c_str(), index, operation_name.c_str());
    cmo_idx.emplace_back(index);
    Status ret = FillDynamicDataCtx(static_cast<size_t>(index), node, ffts_plus_task_def, context_type,
                                    context_id_list);
    if (ret == FAILED) {
      return FAILED;
    }
    // Stop once the per-node limit has been reached.
    count++;
    if (count == kMaxPretchNum) {
      break;
    }
  }
  (void)ge::AttrUtils::SetListInt(node->GetOpDesc(), operation_name + "_idx", cmo_idx);
  // cmo_idx.size() is a number of tensors, not a byte size.
  FFTS_LOGD("Generated dynamic data context count: %zu.", cmo_idx.size());
  return SUCCESS;
}
/*
 * Collects the tensor indices of `node` that are candidates for the current cache operation.
 *
 * PREFETCH: the index of every non-null input data anchor whose input desc exists, is not a scalar and
 * whose index is below kMaxIdx.
 * Any other operation: one entry per output desc that exists, is not a scalar and is not memory-empty; the
 * stored value is a running counter that only advances when an entry is added, capped by kMaxIdx.
 * NOTE(review): because skipped outputs do not advance the counter, the stored value can differ from the
 * output desc position -- confirm this compaction is intended.
 */
std::vector<int> DataTaskBuilder::GetIndices(const ge::NodePtr &node) const {
  std::vector<int> indices;
  if (operation_ != CACHE_OPERATION::PREFETCH) {
    uint32_t idx = 0;
    for (size_t i = 0U; i < node->GetOpDesc()->GetOutputsSize(); ++i) {
      auto desc_ptr = node->GetOpDesc()->GetOutputDescPtr(i);
      if (desc_ptr == nullptr) {
        continue;
      }
      if (desc_ptr->GetShape().IsScalar()) {
        FFTS_LOGD("Node[%s] output tensor at index %zu is a scalar.", node->GetName().c_str(), i);
        continue;
      }
      if (IsMemoryEmpty(*desc_ptr)) {
        FFTS_LOGD("Node[%s] out tensor:%zu is memory empty.", node->GetName().c_str(), i);
        continue;
      }
      if (idx >= kMaxIdx) {
        continue;
      }
      indices.emplace_back(idx);
      idx++;
    }
    return indices;
  }

  for (const auto &in_anchor : node->GetAllInDataAnchors()) {
    if (!in_anchor) {
      continue;
    }
    uint32_t idx = in_anchor->GetIdx();
    auto desc_ptr = node->GetOpDesc()->GetInputDescPtr(idx);
    if (desc_ptr == nullptr) {
      continue;
    }
    if (desc_ptr->GetShape().IsScalar()) {
      FFTS_LOGD("Node [%s] in tensor %u is a scalar.", node->GetName().c_str(), idx);
      continue;
    }
    if (idx >= kMaxIdx) {
      continue;
    }
    indices.emplace_back(in_anchor->GetIdx());
  }
  return indices;
}
/*
 * Tells whether adding pending_num contexts to the curr_num already generated would pass kMaxPretchNum.
 * The limit is only enforced for PREFETCH; every other operation always yields false.
 */
bool DataTaskBuilder::ExceedMaxCtxNum(size_t curr_num, size_t pending_num) const {
  const bool is_prefetch = (operation_ == CACHE_OPERATION::PREFETCH);
  return is_prefetch && ((curr_num + pending_num) > kMaxPretchNum);
}
* we need to calculate the following params based on the memory slicing
* info. */
void DataTaskBuilder::FillManualThreadingParam(const DataContextParam ¶m,
domi::FftsPlusDataCtxDef *data_ctx_def) const {
FFTS_LOGD("start to fill Manual threading param.");
data_ctx_def->set_non_tail_len_inner(param.len_inner);
data_ctx_def->set_non_tail_num_inner(param.num_inner);
data_ctx_def->set_non_tail_num_outter(param.num_outter);
data_ctx_def->set_non_tail_stride_inner(param.stride_inner);
data_ctx_def->set_non_tail_stride_outter(param.stride_outter);
data_ctx_def->set_tail_len_inner(param.len_inner);
data_ctx_def->set_tail_num_inner(param.num_inner);
data_ctx_def->set_tail_num_outter(param.num_outter);
data_ctx_def->set_tail_stride_inner(param.stride_inner);
data_ctx_def->set_tail_stride_outter(param.stride_outter);
}
* between two sgt thread and the offset is equal to size of thread.
* All auto threads use the same data context and hardware use
* the thread id to differenciate them.
*
* For one non-tail thread the base offset is equal to:
* non-tail thread dim size * thread_id.
* For the tail thread the base offset is equal to:
* non-tail thread dim size * (thread_dim - 1) */
void DataTaskBuilder::FillAutoThreadingParam(const vector<DataContextParam> ¶ms,
domi::FftsPlusDataCtxDef *data_ctx_def, const uint32_t &slice_num) const {
if (params.size() <= 1 || slice_num < 1) {
return ;
}
if (data_ctx_def == nullptr) {
return;
}
auto no_tail_num = (slice_num == 1) ? 1 : (slice_num - 1);
FFTS_LOGD("start to fill auto threading param, params[1].base_addr_offset: %ld, slice_num: %u, addr_offset: %ld.",
params[1].base_addr_offset, slice_num, params[1].base_addr_offset / no_tail_num);
data_ctx_def->set_addr_offset(params[1].base_addr_offset / no_tail_num);
data_ctx_def->set_non_tail_len_inner(params[0].len_inner);
data_ctx_def->set_non_tail_num_inner(params[0].num_inner);
data_ctx_def->set_non_tail_num_outter(params[0].num_outter);
data_ctx_def->set_non_tail_stride_inner(params[0].stride_inner);
data_ctx_def->set_non_tail_stride_outter(params[0].stride_outter);
data_ctx_def->set_tail_len_inner(params[1].len_inner);
data_ctx_def->set_tail_num_inner(params[1].num_inner);
data_ctx_def->set_tail_num_outter(params[1].num_outter);
data_ctx_def->set_tail_stride_inner(params[1].stride_inner);
data_ctx_def->set_tail_stride_outter(params[1].stride_outter);
}
/*
 * Reads the base address of input `in_anchor_index` from the node's "input_addrs" list attribute.
 *
 * addr_base is only written when the attribute exists and holds an entry for the index; otherwise a
 * warning is logged and addr_base is left untouched. SUCCESS is returned in every case.
 */
Status DataTaskBuilder::GetAddrBase(size_t in_anchor_index, const ge::NodePtr &node, uint64_t &addr_base) const {
  std::vector<int64_t> input_addrs;
  const bool has_input_addrs = ge::AttrUtils::GetListInt(node->GetOpDesc(), "input_addrs", input_addrs);
  if (!has_input_addrs) {
    FFTS_LOGW("[GenTsk][PrefetchTsk][FillCtxt][Node %s, Type %s] Attribute input_addrs is empty.",
              node->GetName().c_str(), node->GetType().c_str());
    return SUCCESS;
  }

  if (in_anchor_index < input_addrs.size()) {
    addr_base = static_cast<uint64_t>(input_addrs[in_anchor_index]);
    return SUCCESS;
  }

  FFTS_LOGW("[GenTsk][PrefetchTsk][FillCtxt][node %s, type %s] In anchor %zu, the value is greater than or equal to the size of input_addrs %zu.",
            node->GetName().c_str(), node->GetType().c_str(), in_anchor_index, input_addrs.size());
  return SUCCESS;
}
/*
 * Registers the most recently appended context (index ffts_plus_ctx_size() - 1) on the context
 * `context_id` through AddSrcSlotAndBmToCtx.
 *
 * Only AICORE/AIV and MIX_AIC/MIX_AIV contexts are accepted. Returns FAILED when context_id does not
 * address an existing context or when the addressed context has any other type; otherwise the status of
 * AddSrcSlotAndBmToCtx.
 */
Status DataTaskBuilder::UpdateSrcSlotAndPfBm(domi::FftsPlusTaskDef *ffts_plus_task_def, uint32_t context_id) const {
  FFTS_LOGD("Update src slot and pf bm for context %u", context_id);
  FFTS_CHECK_NOTNULL(ffts_plus_task_def);
  // Indexing a repeated field outside [0, size) is undefined behaviour, so validate the id first.
  const int ctx_size = ffts_plus_task_def->ffts_plus_ctx_size();
  if (ctx_size <= 0 || context_id >= static_cast<uint32_t>(ctx_size)) {
    REPORT_FFTS_ERROR("[DataTaskBuilder][UpdateSrcSlotAndPfBm] Context ID %u is out of range, context size is %d.",
                      context_id, ctx_size);
    return FAILED;
  }
  domi::FftsPlusCtxDef *ctx = ffts_plus_task_def->mutable_ffts_plus_ctx(static_cast<int>(context_id));
  FFTS_CHECK_NOTNULL(ctx);
  uint32_t context_type = ctx->context_type();
  // The context to register is the last one appended to the task definition.
  uint32_t prefetch_ctx_id = static_cast<uint32_t>(ctx_size - 1);
  if (context_type == RT_CTX_TYPE_AICORE || context_type == RT_CTX_TYPE_AIV) {
    auto aic_aiv_ctx = ctx->mutable_aic_aiv_ctx();
    return AddSrcSlotAndBmToCtx(prefetch_ctx_id, aic_aiv_ctx);
  } else if (context_type == RT_CTX_TYPE_MIX_AIC || context_type == RT_CTX_TYPE_MIX_AIV) {
    auto mix_aic_aiv_ctx = ctx->mutable_mix_aic_aiv_ctx();
    return AddSrcSlotAndBmToCtx(prefetch_ctx_id, mix_aic_aiv_ctx);
  } else {
    REPORT_FFTS_ERROR("[DataTaskBuilder][UpdateSrcSlotAndPfBm] Context type %u, with ID %u, does not require prefetching.",
                      context_type, context_id);
    return FAILED;
  }
}
* Just for auto_mode and dynamic_mode
* Manual_mode will override it.
*
* Just record first context_id generate by node(B and C) the reason is that you can update all window context's
* (generate by A) succ_list according this first record.
*
* For example: a_1's succ_list(5, 9), window size: 4.
* A (1, 2, 3, 4)
* / \
* / \
* B(5, 6, 7, 8) C(9, 10, 11, 12)
*
* You can get the reset of context's succ_list generate by A:
* a_(1+1)'s succ_list(5+1, 9+1)
* a_(1+2)'s succ_list(5+2, 9+2)
* a_(1+3)'s succ_list(5+3, 9+3)
*/
Status DataTaskBuilder::GetSuccessorContextId(uint32_t out_anchor_index, const ge::NodePtr &node,
std::vector<uint32_t> &succ_list, uint32_t &cons_cnt) {
cons_cnt = 0;
auto anchors = node->GetAllOutDataAnchors();
auto output_size = anchors.size();
if (out_anchor_index >= output_size) {
REPORT_FFTS_ERROR("[GenTask][DataTskBuilder][GetSuccList]Output anchor index %u >= output size %zu of %s.",
out_anchor_index, output_size, node->GetName().c_str());
return FAILED;
}
auto output_anchor = anchors.at(out_anchor_index);
if (output_anchor == nullptr) {
return SUCCESS;
}
for (const auto &peer_in_anchor : output_anchor->GetPeerInDataAnchors()) {
FFTS_CHECK_NOTNULL(peer_in_anchor);
auto peer_in_node = peer_in_anchor->GetOwnerNode();
FFTS_CHECK_NOTNULL(peer_in_node);
vector<uint32_t> peer_in_context_id;
uint32_t ctx_id_tmp = 0;
auto peer_op = peer_in_node->GetOpDesc();
FFTS_CHECK_NOTNULL(peer_op);
* and find its successors. */
if (IsPhonyOp(peer_op)) {
FFTS_LOGD("Peer input op for output %d of %s is PhonyConcat %s.", peer_in_anchor->GetIdx(),
node->GetName().c_str(), peer_op->GetName().c_str());
for (const auto &peer_node_of_pc : peer_in_node->GetOutAllNodes()) {
auto peer_op_of_pc = peer_node_of_pc->GetOpDesc();
vector<uint32_t> peer_in_context_id_list;
(void)ge::AttrUtils::GetListInt(peer_op_of_pc, kAutoCtxIdList, peer_in_context_id_list);
if (peer_in_context_id_list.empty()) {
FFTS_LOGI("PhonyConcat [%s]: peer operation [%s] needs successor list but it does not have a context ID.",
peer_op->GetName().c_str(), peer_op_of_pc->GetName().c_str());
continue;
}
ctx_id_tmp = peer_in_context_id_list[0];
FFTS_LOGD("Peer input op for PhonyConcat is %s, context id is %u.",
peer_op_of_pc->GetName().c_str(), ctx_id_tmp);
peer_in_context_id.emplace_back(ctx_id_tmp);
}
} else {
vector<uint32_t> peer_in_context_id_list;
(void)ge::AttrUtils::GetListInt(peer_op, kAutoCtxIdList, peer_in_context_id_list);
if (peer_in_context_id_list.empty()) {
FFTS_LOGI("Node %s needs successor list but it has a successor %s which do not have a context id.",
node->GetName().c_str(), peer_op->GetName().c_str());
continue;
}
ctx_id_tmp = peer_in_context_id_list[0];
FFTS_LOGD("Peer input op of %s is %s, context id is %u.",
node->GetName().c_str(), peer_op->GetName().c_str(), ctx_id_tmp);
peer_in_context_id.emplace_back(ctx_id_tmp);
}
for (auto ele : peer_in_context_id) {
succ_list.emplace_back(ele);
cons_cnt++;
}
FFTS_LOGD("Total successors(%zu) for node %s is %s.", succ_list.size(), node->GetName().c_str(),
fe::StringUtils::IntegerVecToString(succ_list).c_str());
}
return SUCCESS;
}
/* Selects the cache operation used by the subsequent Gen*DataCtxDef calls. */
void DataTaskBuilder::SetOperation(CACHE_OPERATION operation) {
  this->operation_ = operation;
}
void DataTaskBuilder::SetBurstLen(int64_t burst_len) {
burst_len_ = burst_len;
}
/*
 * Appends every non-null node found in the memory reuse info of `node`
 * (ext attr ge::ATTR_NAME_MEMORY_REUSE_INFO) to redundant_nodes.
 * Nothing is appended when the node carries no reuse info.
 */
void DataTaskBuilder::UpdateRedundantNodes(const ge::NodePtr &node, vector<ge::NodePtr> &redundant_nodes) {
  auto op_desc = node->GetOpDesc();
  map<string, vector<ge::MemReuseInfo>> mem_reuse_infos{};
  mem_reuse_infos = op_desc->TryGetExtAttr(ge::ATTR_NAME_MEMORY_REUSE_INFO, mem_reuse_infos);
  if (mem_reuse_infos.empty()) {
    FFTS_LOGD("[GenTsk][DataTsk][MemReuse] The [node %s, type %s] has no mem_reuse_info to Redundant.",
              node->GetName().c_str(), node->GetType().c_str());
    return;
  }

  for (auto &key_and_infos : mem_reuse_infos) {
    // An empty info list simply contributes nothing.
    for (auto &reuse_info : key_and_infos.second) {
      if (reuse_info.node != nullptr) {
        redundant_nodes.emplace_back(reuse_info.node);
      }
    }
  }
}
/*
 * Walks the memory reuse infos of one output of `node` and calls UpdateInvalidCtxWithMemReuse for every
 * reusing node that qualifies.
 *
 * A reusing node is skipped when it is null, is a phony op, has no owner graph, lives in a different
 * compute graph than `node`, or has already been collected (directly, or through the reuse info of an
 * earlier reusing node via UpdateRedundantNodes).
 *
 * Returns FAILED when mem_reuse_infos is empty, SUCCESS otherwise (including when `node` has no owner
 * graph, in which case nothing is done).
 */
Status DataTaskBuilder::UpdateSuccListWithMemReuse(const ge::NodePtr &node,
                                                   vector<ge::MemReuseInfo> &mem_reuse_infos,
                                                   domi::FftsPlusTaskDef *ffts_plus_task_def,
                                                   int &data_ctx_id,
                                                   const size_t &window_id) {
  if (mem_reuse_infos.empty()) {
    REPORT_FFTS_ERROR("[GenTsk][DataTsk][FillCtxt][node %s, type %s] mem_reuse_infos value is empty.",
                      node->GetName().c_str(), node->GetType().c_str());
    return FAILED;
  }
  auto sub_graph = node->GetOwnerComputeGraph();
  if (sub_graph == nullptr) {
    FFTS_LOGD("[GenTsk][DataTsk][FillCtxt][node %s, type %s] can't get owner compute sub graph.",
              node->GetName().c_str(), node->GetType().c_str());
    return SUCCESS;
  }

  vector<ge::NodePtr> redundant_nodes;
  for (auto &reuse_info : mem_reuse_infos) {
    const auto &reuse_node = reuse_info.node;
    if (reuse_node == nullptr) {
      FFTS_LOGD("[GenTsk][DataTsk][FillCtxt][node %s, type %s] mem reuse info has node nullptr.",
                node->GetName().c_str(), node->GetType().c_str());
      continue;
    }
    auto op_desc = reuse_node->GetOpDesc();
    FFTS_CHECK_NOTNULL(op_desc);
    if (IsPhonyOp(op_desc)) {
      FFTS_LOGD("[GenTsk][DataTsk][FillCtxt][node %s, type %s] mem reuse node is phonyop.",
                node->GetName().c_str(), node->GetType().c_str());
      continue;
    }
    auto ai_graph = reuse_node->GetOwnerComputeGraph();
    if (ai_graph == nullptr) {
      FFTS_LOGD("[GenTsk][DataTsk][FillCtxt][node %s, type %s] can't get owner compute graph.",
                node->GetName().c_str(), node->GetType().c_str());
      continue;
    }
    if (ai_graph != sub_graph) {
      FFTS_LOGD("[GenTsk][DataTsk][FillCtxt] nodes [%s, %s] and [%s, %s] are in different compute graphs.",
                node->GetName().c_str(), node->GetType().c_str(), reuse_node->GetName().c_str(),
                reuse_node->GetType().c_str());
      continue;
    }
    // Each reusing node is handled once.
    const bool already_handled =
        find(redundant_nodes.begin(), redundant_nodes.end(), reuse_node) != redundant_nodes.end();
    if (already_handled) {
      continue;
    }
    redundant_nodes.emplace_back(reuse_node);
    UpdateRedundantNodes(reuse_node, redundant_nodes);
    FFTS_LOGI("[GenTsk][DataTsk][FillCtxt][node %s, type %s] by [node %s, type %s] reuse mem.",
              node->GetName().c_str(), node->GetType().c_str(), reuse_node->GetName().c_str(),
              reuse_node->GetType().c_str());
    UpdateInvalidCtxWithMemReuse(reuse_node, data_ctx_id, window_id, ffts_plus_task_def);
  }
  return SUCCESS;
}
/*
 * Looks up the memory reuse info recorded for output `out_anchor_index` of `node` (map key
 * "output<index>" in ext attr ge::ATTR_NAME_MEMORY_REUSE_INFO) and forwards it to
 * UpdateSuccListWithMemReuse.
 *
 * Returns SUCCESS without doing anything when the node has no reuse info, the output anchor is null, the
 * output has no peer with an owner node, or no non-empty entry exists for this output. Returns FAILED
 * when out_anchor_index is out of range or UpdateSuccListWithMemReuse does not succeed.
 */
Status DataTaskBuilder::GenInvalidSuccListWithMemReuse(const ge::NodePtr &node, size_t out_anchor_index,
                                                       domi::FftsPlusTaskDef *ffts_plus_task_def,
                                                       int &data_ctx_id, const size_t &window_id) {
  FFTS_LOGD("[GenTsk][DataTsk][MemReuse] Node %s of type %s is ready to get out_anchor_index: %zu mem reuse info.",
            node->GetName().c_str(), node->GetType().c_str(), out_anchor_index);
  auto op_desc = node->GetOpDesc();
  map<string, vector<ge::MemReuseInfo>> mem_reuse_infos{};
  mem_reuse_infos = op_desc->TryGetExtAttr(ge::ATTR_NAME_MEMORY_REUSE_INFO, mem_reuse_infos);
  if (mem_reuse_infos.empty()) {
    FFTS_LOGD("[GenTsk][DataTsk][MemReuse][node %s, type %s] has no mem_reuse_info.",
              node->GetName().c_str(), node->GetType().c_str());
    return SUCCESS;
  }

  auto anchors = node->GetAllOutDataAnchors();
  auto output_size = anchors.size();
  if (out_anchor_index >= output_size) {
    REPORT_FFTS_ERROR("[GenTask][DataTskBuilder][MemReuse] Output anchor index %zu >= output size %zu of %s.",
                      out_anchor_index, output_size, node->GetName().c_str());
    return FAILED;
  }
  auto output_anchor = anchors.at(out_anchor_index);
  FFTS_CHECK(output_anchor == nullptr, FFTS_LOGD("The output_anchor is a nullptr."), return SUCCESS);

  // The output counts as consumed as soon as one peer anchor with an owner node is found.
  bool has_successor = false;
  for (const auto &peer_in_anchor : output_anchor->GetPeerInDataAnchors()) {
    if (peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr) {
      has_successor = true;
      break;
    }
  }
  if (!has_successor) {
    FFTS_LOGD("[GenTsk][DataTsk][MemReuse][node %s, type %s]'s out anchor index: %zu has no successor.",
              node->GetName().c_str(), node->GetType().c_str(), out_anchor_index);
    return SUCCESS;
  }

  const std::string mem_info_key = "output" + std::to_string(out_anchor_index);
  auto reuse_iter = mem_reuse_infos.find(mem_info_key);
  if (reuse_iter == mem_reuse_infos.end() || reuse_iter->second.empty()) {
    FFTS_LOGD("[GenTsk][DataTsk][MemReuse] Node %s, type %s, has no reuse info for out anchor index: %zu.",
              node->GetName().c_str(), node->GetType().c_str(), out_anchor_index);
    return SUCCESS;
  }

  const Status update_ret =
      UpdateSuccListWithMemReuse(node, reuse_iter->second, ffts_plus_task_def, data_ctx_id, window_id);
  if (update_ret != SUCCESS) {
    return FAILED;
  }
  return SUCCESS;
}
}