* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "task_builder/mode/thread_task_builder.h"
#include "common/sgt_slice_type.h"
#include "common/aicore_util_attr_define.h"
#include "graph/debug/ge_attr_define.h"
#include "inc/ffts_type.h"
namespace ffts {
namespace {
// Extended-attribute keys that GetNodeContextTypeByNode probes, in this order,
// when looking for an AICPU or DSA context definition on an op.
const std::vector<std::string> kCtxTypeAttrs = {
    kAttrAICpuCtxDef,
    kAttrDsaCtxDef,
};
}  // namespace
// Default construction: no work beyond member default-initialization.
TheadTaskBuilder::TheadTaskBuilder() = default;
// Destruction: nothing to release beyond what the members release themselves.
TheadTaskBuilder::~TheadTaskBuilder() = default;
void TheadTaskBuilder::SetModeType(const ModeType &type) {
mode_type_ = type;
}
/**
 * Work out which task builder type applies to a node.
 *
 * The checks run in this order and the first match wins:
 *   1. the op carries the kAttrAICoreCtxType attribute -> its integer value is used as the type;
 *   2. the op type is listed in kHCCLOpType -> collection-communicate type;
 *   3. the op carries one of the kCtxTypeAttrs extended attributes together with
 *      slice info -> the type registered in kDsaOrAicpuMap for that key
 *      (key gets kAutoSuffix appended when the slice info is in auto thread mode);
 *   4. the op carries the "FFTS_PLUS_TASK_DEF" extended attribute -> runtime-control
 *      type, manual or auto depending on mode_type_.
 *
 * @param node               node to inspect
 * @param task_builder_type  [out] resolved builder type; only written on a match
 * @return SUCCESS when a type was resolved, FAILED when no rule matched; a null node
 *         or a node without op desc is rejected through FFTS_CHECK_NOTNULL
 */
Status TheadTaskBuilder::GetNodeContextTypeByNode(const ge::NodePtr &node, TaskBuilderType &task_builder_type) const {
  // NOTE(review): FFTS_CHECK_NOTNULL presumably logs and returns an error status - confirm against its definition.
  FFTS_CHECK_NOTNULL(node);
  ge::OpDescPtr op_desc = node->GetOpDesc();
  FFTS_CHECK_NOTNULL(op_desc);

  if (op_desc->HasAttr(kAttrAICoreCtxType)) {
    // If the attribute cannot be read as an integer the reserved type is reported.
    int64_t ctx_type = static_cast<int64_t>(TaskBuilderType::EN_TASK_TYPE_RESERVED);
    (void)ge::AttrUtils::GetInt(op_desc, kAttrAICoreCtxType, ctx_type);
    task_builder_type = static_cast<TaskBuilderType>(ctx_type);
    return SUCCESS;
  }

  if (kHCCLOpType.count(op_desc->GetType()) > 0) {
    task_builder_type = TaskBuilderType::EN_TASK_TYPE_COLLECTION_COMMICATE;
    return SUCCESS;
  }

  for (const auto &ctx_type_attr : kCtxTypeAttrs) {
    FftsPlusCtxDefPtr ctx_def_ptr = nullptr;
    ctx_def_ptr = op_desc->TryGetExtAttr(ctx_type_attr, ctx_def_ptr);
    if (ctx_def_ptr == nullptr) {
      continue;
    }
    ffts::ThreadSliceMapPtr slice_info_ptr = nullptr;
    slice_info_ptr = op_desc->TryGetExtAttr(ffts::kAttrSgtStructInfo, slice_info_ptr);
    if (slice_info_ptr == nullptr) {
      continue;
    }
    // The lookup key is the attribute name, suffixed for auto thread mode.
    std::string inner_key = ctx_type_attr;
    if (slice_info_ptr->thread_mode == kAutoMode) {
      inner_key += kAutoSuffix;
    }
    const auto iter = kDsaOrAicpuMap.find(inner_key);
    if (iter != kDsaOrAicpuMap.end()) {
      task_builder_type = iter->second;
      return SUCCESS;
    }
  }

  FftsPlusCtxDefPtr runtime_control_ctx_def = nullptr;
  runtime_control_ctx_def = op_desc->TryGetExtAttr("FFTS_PLUS_TASK_DEF", runtime_control_ctx_def);
  // Log the raw address: a smart pointer object must not be passed through printf-style varargs.
  FFTS_LOGD("Node[%s] rts def:%p.", node->GetName().c_str(),
            static_cast<const void *>(runtime_control_ctx_def.get()));
  if (runtime_control_ctx_def != nullptr) {
    task_builder_type = (mode_type_ == ModeType::MANUAL_MODE_TYPE) ? TaskBuilderType::EN_TASK_TYPE_RUNTIME_CONTROL :
                                                                     TaskBuilderType::EN_TASK_TYPE_RUNTIME_CONTROL_AUTO;
    return SUCCESS;
  }

  FFTS_LOGE("Node [%s] of type [%s] cannot obtain context type.", node->GetName().c_str(), node->GetType().c_str());
  return FAILED;
}
/**
 * Look up the task builder instance held for a task builder type.
 * @param task_builder_type type to resolve
 * @return the member builder associated with the type, or nullptr for any type
 *         that has no case below
 */
FFTSPlusTaskBuilderPtr TheadTaskBuilder::GetTaskBuilder(TaskBuilderType task_builder_type) {
  FFTSPlusTaskBuilderPtr selected_builder = nullptr;
  switch (task_builder_type) {
    // AIC/AIV builders
    case TaskBuilderType::EN_TASK_TYPE_AIC_AIV:
      selected_builder = aic_aiv_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_AIC_AIV_AUTO:
      selected_builder = aic_aiv_auto_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_AIC_AIV_DYNAMIC:
      selected_builder = aic_aiv_dynamic_task_builder_ptr_;
      break;
    // Mixed AIC/AIV builders
    case TaskBuilderType::EN_TASK_TYPE_MIX_AIC_AIV_DYNAMIC:
      selected_builder = mix_aic_aiv_dynamic_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_MIX_AIC_AIV:
      selected_builder = mix_aic_aiv_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_MIX_AIC_AIV_AUTO:
      selected_builder = mix_aic_aiv_auto_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_MIX_L2_AIC_AIV:
      selected_builder = mix_L2_task_builder_ptr_;
      break;
    // Collection-communicate, AICPU, runtime-control and DSA builders
    case TaskBuilderType::EN_TASK_TYPE_COLLECTION_COMMICATE:
      selected_builder = collection_ops_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_AICPU:
      selected_builder = aicpu_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_AICPU_AUTO:
      selected_builder = aicpu_auto_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_RUNTIME_CONTROL:
      selected_builder = runtime_ops_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_RUNTIME_CONTROL_AUTO:
      selected_builder = runtime_ops_auto_task_builder_ptr_;
      break;
    case TaskBuilderType::EN_TASK_TYPE_DSA:
      selected_builder = dsa_ops_task_builder_ptr_;
      break;
    default:
      // Unlisted types keep the nullptr result.
      break;
  }
  return selected_builder;
}
/**
 * Generate the data context definitions of a node for the given mode type.
 *
 * For the manual, auto and dynamic mode types a prefetch builder, an invalidate
 * builder and a write-back builder are created and run in that order; the first
 * one that does not report SUCCESS stops the sequence.
 *
 * @param node                node the contexts are generated for
 * @param ffts_plus_task_def  task definition handed to every builder
 * @param mode_type           selects which family of builders is used
 * @return FAILED if any builder fails; SUCCESS otherwise, including for a mode
 *         type that has no case below (nothing is generated then)
 */
Status TheadTaskBuilder::GenerateDataTaskDef(const ge::NodePtr &node, domi::FftsPlusTaskDef *ffts_plus_task_def,
                                             const ModeType &mode_type) const {
  FFTS_LOGD("Current FFT mode type is %d.", static_cast<int>(mode_type));
  switch (mode_type) {
    case ModeType::MANUAL_MODE_TYPE: {
      PrefetchTaskBuilder prefetch_builder;
      OutTaskBuilder invalidate_builder(CACHE_OPERATION::INVALIDATE);
      OutTaskBuilder write_back_builder(CACHE_OPERATION::WRITE_BACK);
      // Short-circuit keeps the prefetch -> invalidate -> write-back order and stops at the first failure.
      if ((prefetch_builder.GenManualDataCtxDef(node, ffts_plus_task_def) != SUCCESS) ||
          (invalidate_builder.GenManualDataCtxDef(node, ffts_plus_task_def) != SUCCESS) ||
          (write_back_builder.GenManualDataCtxDef(node, ffts_plus_task_def) != SUCCESS)) {
        return FAILED;
      }
      break;
    }
    case ModeType::AUTO_MODE_TYPE: {
      PrefetchAutoTaskBuilder prefetch_builder;
      OutAutoTaskBuilder invalidate_builder(CACHE_OPERATION::INVALIDATE);
      OutAutoTaskBuilder write_back_builder(CACHE_OPERATION::WRITE_BACK);
      if ((prefetch_builder.GenAutoDataCtxDef(node, ffts_plus_task_def) != SUCCESS) ||
          (invalidate_builder.GenAutoDataCtxDef(node, ffts_plus_task_def) != SUCCESS) ||
          (write_back_builder.GenAutoDataCtxDef(node, ffts_plus_task_def) != SUCCESS)) {
        return FAILED;
      }
      break;
    }
    case ModeType::DYNAMIC_MODE_TYPE: {
      PrefetchDynamicTaskBuilder prefetch_builder;
      OutDynamicTaskBuilder invalidate_builder(CACHE_OPERATION::INVALIDATE);
      OutDynamicTaskBuilder write_back_builder(CACHE_OPERATION::WRITE_BACK);
      if ((prefetch_builder.GenDynamicDataCtxDef(node, ffts_plus_task_def) != SUCCESS) ||
          (invalidate_builder.GenDynamicDataCtxDef(node, ffts_plus_task_def) != SUCCESS) ||
          (write_back_builder.GenDynamicDataCtxDef(node, ffts_plus_task_def) != SUCCESS)) {
        return FAILED;
      }
      break;
    }
    default:
      // Any other mode type generates nothing and is reported as success.
      break;
  }
  return SUCCESS;
}
/**
 * Tell whether a node is one of the "no context" cases.
 * @param node node to inspect
 * @return true when the op type is listed in NO_NEED_GEN_TASK_OP_TYPE or
 *         CONTROL_OP_V2_TYPE, or when the op's ge::ATTR_NAME_NOTASK attribute
 *         reads true; false otherwise
 */
bool TheadTaskBuilder::IsNoCtx(const ge::NodePtr &node) const {
  ge::OpDescPtr op_desc = node->GetOpDesc();
  const bool type_is_listed = (NO_NEED_GEN_TASK_OP_TYPE.count(op_desc->GetType()) != 0) ||
                              (CONTROL_OP_V2_TYPE.count(op_desc->GetType()) != 0);
  if (type_is_listed) {
    return true;
  }
  // Stays false when the attribute cannot be read.
  bool no_task_flag = false;
  (void)ge::AttrUtils::GetBool(op_desc, ge::ATTR_NAME_NOTASK, no_task_flag);
  return no_task_flag;
}
/**
 * Collect the memset nodes attached to the direct nodes of a graph.
 *
 * Every direct node is asked for the node stored under the
 * fe::ATTR_NAME_MEMSET_NODE extended attribute. When that node exists and its
 * type is fe::MEMSET_OP_TYPE, kOnlyAddContext is set to true on its op desc and,
 * unless it is already present in memset_nodes, it is appended to both output
 * vectors.
 *
 * @param pre_sub_graph_nodes  [in,out] receives each newly found memset node
 * @param memset_nodes         [in,out] memset nodes seen so far; used for de-duplication
 * @param sgt_graph            graph whose direct nodes are scanned
 * @return SUCCESS after the scan; a null direct node or a direct node without
 *         op desc is rejected through FFTS_CHECK_NOTNULL
 */
Status TheadTaskBuilder::GenFftsPlusContextIdWithMemSet(std::vector<ge::NodePtr> &pre_sub_graph_nodes,
                                                        std::vector<ge::NodePtr> &memset_nodes,
                                                        ge::ComputeGraph &sgt_graph) const {
  for (const auto &node : sgt_graph.GetDirectNode()) {
    // NOTE(review): FFTS_CHECK_NOTNULL presumably logs and returns an error status - confirm against its definition.
    FFTS_CHECK_NOTNULL(node);
    ge::OpDescPtr op_desc = node->GetOpDesc();
    FFTS_CHECK_NOTNULL(op_desc);

    ge::NodePtr memset_node = nullptr;
    memset_node = op_desc->TryGetExtAttr(fe::ATTR_NAME_MEMSET_NODE, memset_node);
    if (memset_node == nullptr) {
      continue;
    }
    if (!(memset_node->GetType() == fe::MEMSET_OP_TYPE)) {
      continue;
    }

    FFTS_LOGD("Get MemSet Node: %s, type: %s from Node: %s.", memset_node->GetName().c_str(),
              memset_node->GetType().c_str(), node->GetName().c_str());
    (void)ge::AttrUtils::SetBool(memset_node->GetOpDesc(), kOnlyAddContext, true);

    // The same memset node may be attached to several nodes; record it only once.
    const auto existing = std::find(memset_nodes.begin(), memset_nodes.end(), memset_node);
    if (existing != memset_nodes.end()) {
      continue;
    }
    pre_sub_graph_nodes.emplace_back(memset_node);
    memset_nodes.emplace_back(memset_node);
  }
  return SUCCESS;
}
}