* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file expand_function.cpp
* \brief Pass that expands the operations of a tensor graph function into a tile graph.
*/
#include "passes/tensor_graph_pass/expand_function.h"
#include <map>
#include "interface/function/function.h"
#include "passes/pass_utils/pass_operation_utils.h"
#include "interface/tensor/raw_tensor.h"
#include "interface/tensor/logical_tensor.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/program/program.h"
#include "interface/operation/operation_impl.h"
#include "interface/configs/config_manager.h"
#include "passes/pass_check/expand_function_checker.h"
#include "passes/statistics/tensor_and_tile_graph_statistic.h"
#include "passes/pass_log/pass_log.h"
#include "passes/pass_utils/graph_utils.h"
#include "tilefwk/error_code.h"
#define MODULE_NAME "ExpandFunction"
using namespace npu::tile_fwk;
namespace npu::tile_fwk {
// Opcodes that Expandfunction does not expand: operations with one of these opcodes are
// re-added to the function unchanged through ProcessForNotExpandOp.
const std::unordered_set<Opcode> ExpandFunction::kNotNeedExpandOps = {
Opcode::OP_VIEW, Opcode::OP_ASSEMBLE, Opcode::OP_NOP, Opcode::OP_ATOMIC_RMW};
/**
 * \brief Clears the producer and consumer collections of every input and output operand
 *        of the given operations.
 *
 * Operands are processed in order; when a null entry is hit, operands visited before it
 * have already been cleared.
 *
 * \param tensorOperations Operations whose operands are reset.
 * \return SUCCESS when every operand was cleared; FAILED (after logging) when an operation
 *         or one of its operands is null.
 */
Status ExpandFunction::ClearIOOperand(const std::vector<OperationPtr>& tensorOperations) const
{
    for (const auto& op : tensorOperations) {
        // Every statement below dereferences op, so a null entry must be rejected first.
        if (op == nullptr) {
            APASS_LOG_ERROR_F(Elements::Operation, "Encountered null operation in function.");
            return FAILED;
        }
        for (auto& iOperand : op->GetIOperands()) {
            if (iOperand == nullptr) {
                APASS_LOG_ERROR_F(
                    Elements::Operation, "Op:%s[%d] input is null.%s", op->GetOpcodeStr().c_str(), op->GetOpMagic(),
                    GetFormatBacktrace(*op).c_str());
                return FAILED;
            }
            iOperand->GetConsumers().clear();
            iOperand->GetProducers().clear();
        }
        for (auto& oOperand : op->GetOOperands()) {
            if (oOperand == nullptr) {
                APASS_LOG_ERROR_F(
                    Elements::Operation, "Op:%s[%d] output is null.%s", op->GetOpcodeStr().c_str(), op->GetOpMagic(),
                    GetFormatBacktrace(*op).c_str());
                return FAILED;
            }
            oOperand->GetConsumers().clear();
            oOperand->GetProducers().clear();
        }
    }
    return SUCCESS;
}
/**
 * \brief Re-adds an operation that is not expanded to the function and carries its
 *        attributes over to the newly added operation.
 *
 * \param function Function that receives the new operation.
 * \param op       Source operation whose opcode, operands and attributes are reused.
 */
void ExpandFunction::ProcessForNotExpandOp(Function& function, Operation& op) const
{
    auto& clonedOp = PassOperationUtils::AddOperation(
        function, op.GetOpcode(), op.GetIOperands(), op.GetOOperands(), nullptr, ir::Span::Unknown(), false);
    clonedOp.SetOpAttribute(op.GetOpAttribute());
    clonedOp.SetScopeInfo(op.GetScopeInfo());
    clonedOp.CopyAttrFrom(op, OP_EMUOP_PREFIX);

    // Forwards one integer attribute, but only when the source operation actually has it.
    const auto forwardIntAttr = [&op, &clonedOp](const auto& key) {
        if (op.HasAttribute(key)) {
            clonedOp.SetAttribute(key, op.GetIntAttribute(key));
        }
    };
    forwardIntAttr(OpAttributeKey::inplaceIdx);
    forwardIntAttr(OpAttributeKey::rmwMode);
}
/**
 * \brief Runs the default-enabled pre-check of ExpandFunctionChecker on the function.
 *
 * \param function Function to check.
 * \return The status reported by ExpandFunctionChecker::DoDefaultEnabledPreCheck.
 */
Status ExpandFunction::DefaultEnabledPreCheck(Function& function)
{
    ExpandFunctionChecker preChecker;
    const Status checkStatus = preChecker.DoDefaultEnabledPreCheck(function);
    return checkStatus;
}
/**
 * \brief Runs the post-check of ExpandFunctionChecker on the function.
 *
 * \param function Function to check.
 * \return The status reported by ExpandFunctionChecker::DoPostCheck.
 */
Status ExpandFunction::PostCheck(Function& function)
{
    ExpandFunctionChecker postChecker;
    const Status checkStatus = postChecker.DoPostCheck(function);
    return checkStatus;
}
/**
 * \brief Validates the scope information of every operation in the function.
 *
 * Checks performed, in this order:
 *  - an operation with scopeId -1 must have allowParallelMerge and allowCrossScopeMerge unset;
 *  - all operations sharing a scopeId must agree on both merge flags;
 *  - a scope containing both AIC and AIV core types is rejected unless
 *    GraphUtils::IsCVMixPlatform() returns true.
 *
 * \param function Function whose operations are inspected.
 * \param oss      Receives a description of the first violation found.
 * \return SUCCESS when no violation is found, FAILED otherwise.
 */
Status ExpandFunction::VerifyScopeInfo(Function& function, std::ostringstream& oss) const
{
    std::unordered_map<int, Operation::ScopeInfo> firstSeenInfo;
    std::unordered_map<int, std::unordered_set<CoreType>> coreTypesByScope;

    for (auto& op : function.Operations(false)) {
        const auto& scope = op.GetScopeInfo();

        if (scope.scopeId == -1) {
            if (scope.allowParallelMerge || scope.allowCrossScopeMerge) {
                oss << "Op " << op.GetOpcodeStr() << "[" << op.GetOpMagic()
                    << "]: allowParallelMerge and allowCrossScopeMerge must be false when scopeId is -1.";
                return FAILED;
            }
            // Operations without a scope take no part in the per-scope checks.
            continue;
        }

        // The first operation seen for a scope defines the flags all later ones must match.
        const auto [entry, inserted] = firstSeenInfo.try_emplace(scope.scopeId, scope);
        if (!inserted) {
            const auto& recorded = entry->second;
            const bool sameFlags = recorded.allowParallelMerge == scope.allowParallelMerge &&
                                   recorded.allowCrossScopeMerge == scope.allowCrossScopeMerge;
            if (!sameFlags) {
                oss << "Op " << op.GetOpcodeStr() << "[" << op.GetOpMagic() << "]: scopeId=" << scope.scopeId
                    << " has conflicting allowParallelMerge or allowCrossScopeMerge settings.";
                return FAILED;
            }
        }
        coreTypesByScope[scope.scopeId].insert(op.GetCoreType());
    }

    for (auto& [scopeId, coreTypes] : coreTypesByScope) {
        const bool hasBothCoreTypes = coreTypes.count(CoreType::AIC) > 0 && coreTypes.count(CoreType::AIV) > 0;
        // The platform is only queried for scopes that actually contain both core types.
        if (hasBothCoreTypes && !GraphUtils::IsCVMixPlatform()) {
            oss << "Cannot mix cube and vector op on a CV seperate platform in function: " << function.GetRawName()
                << ", please check your setting: sg_set_scope=" << scopeId;
            return FAILED;
        }
    }
    return SUCCESS;
}
/**
 * \brief Entry point of the pass: verifies the function and then expands it.
 *
 * Steps:
 *  1. VerifyScopeInfo validates the scope settings of all operations.
 *  2. For each operation, the verify entry registered in OpcodeManager for its opcode
 *     (if any) is invoked; the first failing verification aborts the pass.
 *  3. Expandfunction performs the expansion.
 *
 * \param function Function to process.
 * \return SUCCESS when verification and expansion succeed, FAILED otherwise (after logging).
 */
Status ExpandFunction::RunOnFunction(Function& function)
{
    APASS_LOG_INFO_F(Elements::Function, "Start ExpandFunction function [%s].", function.GetRawName().c_str());
    std::ostringstream oss;
    if (VerifyScopeInfo(function, oss) != SUCCESS) {
        APASS_LOG_ERROR_C(
            OperationErr::OP_SCOPE_ERROR, Elements::Function, "Function[%s] ScopeInfo verification failed: %s",
            function.GetRawName().c_str(), oss.str().c_str());
        return FAILED;
    }
    for (auto& op : function.Operations(false)) {
        auto verifyOperationEntry = OpcodeManager::Inst().GetVerifyOperationEntry(op.GetOpcode());
        // Fail fast: no later operation is verified once one verification has failed, so
        // there is nothing to gain from walking the rest of the operation list.
        if (verifyOperationEntry && !verifyOperationEntry(function, op, oss)) {
            APASS_LOG_ERROR_F(
                Elements::Function, "Function[%s] ExpandFunction failed: %s", function.GetRawName().c_str(),
                oss.str().c_str());
            return FAILED;
        }
    }
    if (Expandfunction(function) != SUCCESS) {
        APASS_LOG_ERROR_F(Elements::Function, "Function[%s] ExpandFunction failed.", function.GetRawName().c_str());
        return FAILED;
    }
    APASS_LOG_INFO_F(
        Elements::Function, "Function[%s] operation size is: %zu after expansion.", function.GetMagicName().c_str(),
        function.Operations().size());
    APASS_LOG_INFO_F(Elements::Function, "End ExpandFunction function [%s].", function.GetRawName().c_str());
    return SUCCESS;
}
/**
 * \brief Rebuilds the operation list of a tensor graph function as a tile graph.
 *
 * Functions whose graph type is not TENSOR_GRAPH are left untouched. Otherwise the graph
 * type is switched to TILE_GRAPH, the current operations are snapshotted, the function's
 * operations are reset and the operands' producer/consumer links are cleared. Each
 * snapshotted operation is then handled as follows:
 *  - OP_PRINT: skipped (neither re-added nor expanded);
 *  - opcodes in kNotNeedExpandOps: re-added through ProcessForNotExpandOp;
 *  - everything else: expanded through ExpandOperation, with the operation's span and
 *    semantic label set as current beforehand.
 *
 * \param function Function to expand.
 * \return SUCCESS on success or when the function is skipped; FAILED (after logging) when
 *         a null operation is found or a sub-step fails.
 */
Status ExpandFunction::Expandfunction(Function& function) const
{
    if (!function.IsGraphType(GraphType::TENSOR_GRAPH)) {
        APASS_LOG_INFO_F(
            Elements::Function, "Function %s is not static tensor graph, skip expanding.",
            function.GetRawName().c_str());
        return SUCCESS;
    }
    // NOTE(review): the failure paths below return with expandFunctionAccelerate still set to
    // true - confirm that callers abandon the function when this pass reports FAILED.
    function.expandFunctionAccelerate = true;
    function.SetGraphType(GraphType::TILE_GRAPH);

    std::vector<OperationPtr> tensorOperations;
    auto operationViewer = function.Operations();
    tensorOperations.reserve(operationViewer.size());
    for (size_t i = 0; i < operationViewer.size(); i++) {
        // Validate while snapshotting: ClearIOOperand and the expansion loop dereference
        // every entry, so a null operation has to be rejected before the operations are
        // reset and before anything touches the entries.
        if (operationViewer.operations_[i] == nullptr) {
            APASS_LOG_ERROR_F(Elements::Operation, "Encountered null operation in function.");
            return FAILED;
        }
        tensorOperations.emplace_back(operationViewer.operations_[i]);
    }

    function.ResetOperations();
    if (ClearIOOperand(tensorOperations) != SUCCESS) {
        APASS_LOG_ERROR_F(Elements::Operation, "ClearIOOperand failed.");
        return FAILED;
    }

    for (auto& op : tensorOperations) {
        if (op->GetOpcode() == Opcode::OP_PRINT) {
            continue;
        }
        if (kNotNeedExpandOps.count(op->GetOpcode())) {
            ProcessForNotExpandOp(function, *op);
            continue;
        }
        ir::Span::SetCurrent(op->GetSpan());
        config::SetSemanticLabel(op->GetSemanticLabel());
        // Remember how many operations existed so the ones added by this expansion can be
        // identified afterwards.
        size_t opListPreSize = function.Operations(false).size();
        if (ExpandOperation(function, *op) != SUCCESS) {
            APASS_LOG_ERROR_F(Elements::Operation, "ExpandOperation failed.");
            return FAILED;
        }
        auto opListPost = function.Operations(false);
        if (op->GetOpcode() == Opcode::OP_ADDS) {
            // Only for OP_ADDS: apply CopyAttrFrom with OP_EMUOP_PREFIX to every operation
            // that the expansion appended.
            for (size_t i = opListPreSize; i < opListPost.size(); i++) {
                auto& newOp = opListPost[i];
                newOp.CopyAttrFrom(*op, OP_EMUOP_PREFIX);
            }
        }
        ir::Span::ClearCurrent();
    }
    function.BuildTensorMap();
    function.expandFunctionAccelerate = false;
    return SUCCESS;
}
/**
 * \brief Expands a single operation into the function under the operation's scope settings.
 *
 * The SG_SET_SCOPE pass option is set to {scopeId, allowParallelMerge, allowCrossScopeMerge}
 * of the operation for the duration of ExpandOperationInto and set to {-1, 0, 0} afterwards.
 *
 * \param function Function that receives the expanded operations.
 * \param op       Operation to expand.
 * \return Always SUCCESS.
 */
Status ExpandFunction::ExpandOperation(Function& function, Operation& op) const
{
    const auto& scope = op.GetScopeInfo();
    const auto scopeId = static_cast<int64_t>(scope.scopeId);
    const auto parallelMerge = static_cast<int64_t>(scope.allowParallelMerge);
    const auto crossScopeMerge = static_cast<int64_t>(scope.allowCrossScopeMerge);

    std::vector<int64_t> activeScope{scopeId, parallelMerge, crossScopeMerge};
    config::SetPassOption(SG_SET_SCOPE, activeScope);

    ExpandOperationInto(function, op.GetTileShape(), op.GetOpcode(), op.GetIOperands(), op.GetOOperands(), op);

    // Put the option back to {-1, 0, 0} once the expansion call has returned.
    config::SetPassOption(SG_SET_SCOPE, std::vector<int64_t>{-1, 0, 0});
    return SUCCESS;
}
/**
 * \brief Runs the tensor graph health check on the function, bracketed by start and end
 *        log lines.
 *
 * \param function   Function passed to HealthCheckTensorGraph.
 * \param folderPath Folder path forwarded to HealthCheckTensorGraph.
 */
void ExpandFunction::DoHealthCheckBefore(Function& function, const std::string& folderPath)
{
    APASS_LOG_INFO_F(Elements::Operation, "Before ExpandFunction, Health Report: TensorGraph START");

    // The file name handed to the health check is the function's dump file prefix.
    std::string dumpPrefix = GetDumpFilePrefix(function, true);
    HealthCheckTensorGraph(function, folderPath, dumpPrefix);

    APASS_LOG_INFO_F(Elements::Operation, "Before ExpandFunction, Health Report: TensorGraph END");
}
}