* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file recorder.cpp
* \brief
*/
#include "interface/configs/config_manager.h"
#include "passes/pass_mgr/pass_manager.h"
namespace npu::tile_fwk {
// Magic name under which the program entry function is registered; it is looked up with
// Program::GetFunctionByMagicName().
const std::string PROGRAM_ENTRY_FUNCTION_NAME{"PROGRAM_ENTRY"};
/**
 * Recursively merge duplicated iocasts over a function call tree.
 *
 * @param func Function to process. When nullptr, the walk starts at the direct callees of the
 *             program entry function (the entry function itself is not rewritten); if the entry
 *             function cannot be found, nothing happens.
 *
 * A function without callees is left untouched. Otherwise MergeFunctionDupIocast() and
 * RemoveCallOpViewAssemble() are applied to it, the tensor graph is optionally dumped
 * (KEY_PRINT_GRAPH), and then every callee is processed the same way.
 */
static void MergeAllFuncDupIocast(Function* func)
{
    if (func == nullptr) {
        auto entryFunc = Program::GetInstance().GetFunctionByMagicName(PROGRAM_ENTRY_FUNCTION_NAME);
        if (entryFunc == nullptr) {
            return;
        }
        auto entryChildren = entryFunc->GetCalleeFunctionList();
        for (auto child : entryChildren) {
            MergeAllFuncDupIocast(child);
        }
        return;
    }

    FE_LOGI("Merge Duplicated Iocast for function: %s", func->GetMagicName().c_str());
    // Take a copy of the callee list before the rewrites below run on this function.
    auto children = func->GetCalleeFunctionList();
    if (children.empty()) {
        return;
    }

    func->MergeFunctionDupIocast();
    func->RemoveCallOpViewAssemble();

    const bool dumpGraph =
        config::GetPassDefaultConfig(KEY_PRINT_GRAPH, false) && func->IsGraphType(GraphType::TENSOR_GRAPH);
    if (dumpGraph) {
        func->DumpJsonFile(config::LogTensorGraphFolder() + "/" + func->GetRawName() + "_remove_dup.json");
        func->DumpFile(config::LogTensorGraphFolder() + "/" + func->GetRawName() + "_remove_dup.tifwkgr");
    }

    for (auto child : children) {
        MergeAllFuncDupIocast(child);
    }
}
/**
 * Begin recording funcName as a dynamic function.
 *
 * Requires config::GetFunctionType() == FunctionType::DYNAMIC (enforced by CHECK). The
 * start-argument tensors are marked in the tensor slot manager as inputs, outputs, and in-place
 * pairs. The new function receives a DyndevFunctionAttribute holding the tensor references
 * together with their storages (GetStorage(false)), and becomes the program's current dynamic
 * function.
 *
 * @param startArgsInputTensorList  Tensors marked as inputs.
 * @param startArgsOutputTensorList Tensors marked as outputs.
 * @param inplaceArgs               Pairs forwarded to TensorSlotManager::MarkInplace in order.
 */
void RecordFunc::RecordDynFuncInner(
    const std::vector<std::reference_wrapper<const Tensor>>& startArgsInputTensorList,
    const std::vector<std::reference_wrapper<const Tensor>>& startArgsOutputTensorList,
    const std::vector<std::pair<std::reference_wrapper<const Tensor>, std::reference_wrapper<const Tensor>>>&
        inplaceArgs)
{
    CHECK(FeError::INVALID_TYPE, config::GetFunctionType() == FunctionType::DYNAMIC)
        << "Function graph type: " << GetFunctionTypeNameDict().Find(config::GetFunctionType());
#if ENABLE_HIDDENLOOP
    // Hidden DYNAMIC_LOOP wrapper (LoopRange(1)); RecordFunc::begin()/end() iterate it when present.
    recordLoopFunc_ = std::make_unique<RecordLoopFunc>(
        funcName + "_loop", FunctionType::DYNAMIC_LOOP, funcName + "_unused_hidden_record_func_loop_idx", LoopRange(1));
#endif
    Program::GetInstance().BeginFunction(funcName, config::GetFunctionType());

    std::shared_ptr<TensorSlotManager> slotManager = Program::GetInstance().GetTensorSlotManager();
    for (const auto& inputRef : startArgsInputTensorList) {
        slotManager->MarkInput(inputRef.get());
    }
    for (const auto& outputRef : startArgsOutputTensorList) {
        slotManager->MarkOutput(outputRef.get());
    }
    for (const auto& [firstRef, secondRef] : inplaceArgs) {
        slotManager->MarkInplace(firstRef.get(), secondRef.get());
    }

    dynFunc_ = Program::GetInstance().GetCurrentFunction();
    dynFunc_->SetUnderDynamicFunction(true);
    dynFunc_->SetSpan(ir::Span::Current());

    auto attr = std::make_shared<DyndevFunctionAttribute>();
    attr->startArgsInputTensorList = startArgsInputTensorList;
    attr->startArgsOutputTensorList = startArgsOutputTensorList;

    auto& inputStorages = attr->startArgsInputLogicalTensorList;
    auto& outputStorages = attr->startArgsOutputLogicalTensorList;
    inputStorages.resize(startArgsInputTensorList.size());
    outputStorages.resize(startArgsOutputTensorList.size());

    size_t slot = 0;
    for (const auto& inputRef : startArgsInputTensorList) {
        inputStorages[slot++] = inputRef.get().GetStorage(false);
    }
    slot = 0;
    for (const auto& outputRef : startArgsOutputTensorList) {
        outputStorages[slot++] = outputRef.get().GetStorage(false);
    }

    dynFunc_->SetDyndevAttribute(attr);
    Program::GetInstance().SetCurrentDynamicFunction(dynFunc_);
}
/**
 * Start recording a function named FUNCTION_PREFIX + name with the configured function type.
 * The config log is reset and tensor-slot recording is switched on before the function begins.
 */
RecordFunc::RecordFunc(const std::string& name) : funcName(FUNCTION_PREFIX + name)
{
    ConfigManager::Instance().ResetLog();
    auto& program = Program::GetInstance();
    program.GetTensorSlotManager()->SetRecording(true);
    program.BeginFunction(funcName, config::GetFunctionType());
}
/**
 * Start recording a function with explicit op arguments.
 *
 * Under FunctionType::DYNAMIC the explicit args become the dynamic function's start inputs (no
 * outputs, no in-place pairs). Otherwise they are passed to BeginFunction for a TENSOR_GRAPH
 * function.
 */
RecordFunc::RecordFunc(const std::string& name, const std::vector<std::reference_wrapper<const Tensor>>& explicitOpArgs)
    : funcName(FUNCTION_PREFIX + name)
{
    ConfigManager::Instance().ResetLog();
    Program::GetInstance().GetTensorSlotManager()->SetRecording(true);
    if (config::GetFunctionType() != FunctionType::DYNAMIC) {
        Program::GetInstance().BeginFunction(
            funcName, config::GetFunctionType(), GraphType::TENSOR_GRAPH, explicitOpArgs);
        return;
    }
    RecordDynFuncInner(explicitOpArgs, {}, {});
}
/**
 * Start recording a dynamic function with explicit start inputs, outputs and in-place pairs.
 * RecordDynFuncInner CHECKs that the configured function type is FunctionType::DYNAMIC.
 */
RecordFunc::RecordFunc(
    const std::string& name, const std::vector<std::reference_wrapper<const Tensor>>& startArgsInputTensorList,
    const std::vector<std::reference_wrapper<const Tensor>>& startArgsOutputTensorList,
    const std::vector<std::pair<std::reference_wrapper<const Tensor>, std::reference_wrapper<const Tensor>>>&
        inplaceArgs)
    : funcName(FUNCTION_PREFIX + name)
{
    ConfigManager::Instance().ResetLog();
    auto& program = Program::GetInstance();
    program.GetTensorSlotManager()->SetRecording(true);
    RecordDynFuncInner(startArgsInputTensorList, startArgsOutputTensorList, inplaceArgs);
}
/// Whether the verify option KEY_ENABLE_PASS_VERIFY is enabled.
inline bool IsVerifyEnable()
{
    const bool enabled = config::GetVerifyOption<bool>(KEY_ENABLE_PASS_VERIFY);
    return enabled;
}
void RecordFunc::EndFunction()
{
if (recordLoopFunc_) {
recordLoopFunc_.reset();
}
if (IsVerifyEnable()) {
FE_LOGI("FlowVerify has been enable.");
config::SetRunDataOption(KEY_FLOW_VERIFY_PATH, config::GetAbsoluteTopFolder() + "/verify");
}
Defer clean([this]() {
isEnd_ = true;
Program::GetInstance().GetTensorSlotManager()->SetRecording(false);
});
(void)Program::GetInstance().EndFunction(funcName);
if (dynFunc_) {
Program::GetInstance().SetLastFunction(dynFunc_);
if (dynFunc_->IsDyndev()) {
Program::GetInstance().ClearEmptyHiddenFunction();
dynFunc_->CleanRedundantOutCast();
dynFunc_->InferParamDirection();
auto attr = dynFunc_->GetDyndevAttribute();
attr->getTensorDataDescDict.clear();
dynFunc_->ApplyLoopCallOrderGroup();
if (config::GetVerifyOption<bool>(KEY_ENABLE_PASS_VERIFY)) {
Program::GetInstance().VerifyTensorGraph();
}
MergeAllFuncDupIocast(nullptr);
PassManager::Instance().RunPass(
Program::GetInstance(), *Program::GetInstance().GetFunctionByMagicName(PROGRAM_ENTRY_FUNCTION_NAME),
"FunctionUnroll");
Program::GetInstance().UpdateCompileTask();
}
Program::GetInstance().SetCurrentDynamicFunction(nullptr);
dynFunc_->SetUnderDynamicFunction(false);
}
}
/// Start of the recording range: wraps the hidden loop's iterator when one exists.
RecordFunc::Iterator RecordFunc::begin()
{
    if (!recordLoopFunc_) {
        // No hidden loop: the iterator tracks progress with its own cur_ counter.
        return Iterator(*this);
    }
    return Iterator(*this, recordLoopFunc_->begin());
}
/// End sentinel of the recording range: wraps the hidden loop's end marker when one exists.
RecordFunc::IteratorEnd RecordFunc::end()
{
    if (!recordLoopFunc_) {
        return IteratorEnd(*this);
    }
    return IteratorEnd(*this, recordLoopFunc_->end());
}
/**
 * Advance the recorder iterator. Delegates to the wrapped hidden-loop iterator when present;
 * otherwise sets cur_ to 1, which makes operator!= report the end.
 */
RecordFunc::Iterator RecordFunc::Iterator::operator++()
{
    if (wrappedIter_.has_value()) {
        ++(*wrappedIter_);
    } else {
        cur_ = 1;
    }
    return *this;
}
/**
 * Loop-continuation test. With a wrapped hidden-loop iterator, compares against rhs's wrapped
 * end (which must be set). Without one, iteration continues until cur_ becomes 1.
 */
bool RecordFunc::Iterator::operator!=(const IteratorEnd& rhs)
{
    if (wrappedIter_.has_value()) {
        FE_ASSERT(rhs.wrappedEnd.has_value()) << "Input param rhs has no value";
        return *wrappedIter_ != *rhs.wrappedEnd;
    }
    return cur_ != 1;
}
/**
 * Open a loop region and push it onto the program loop stack (popped by the destructor).
 *
 * @param name             Loop name; FUNCTION_PREFIX is prepended.
 * @param funcType         Must be FunctionType::STATIC or FunctionType::DYNAMIC_LOOP.
 * @param iterName         Name of the symbolic loop index.
 * @param range            Iteration range of the loop.
 * @param unrollList       Default unroll factors. Every entry must be > 0; see
 *                         GenDefaultUnrollTimes for when they are used.
 * @param submitBeforeLoop Stored and forwarded to the loop attribute.
 * @param parallel         Stored and forwarded to the loop attribute; a parallel loop may not be
 *                         nested inside another parallel loop.
 */
RecordLoopFunc::RecordLoopFunc(
    const std::string& name, FunctionType funcType, const std::string& iterName, const LoopRange& range,
    const std::set<int>& unrollList, bool submitBeforeLoop, bool parallel)
    : name_(FUNCTION_PREFIX + name),
      iterName_(iterName),
      loopRange_(std::make_shared<LoopRange>(range)),
      submitBeforeLoop_(submitBeforeLoop),
      parallel_(parallel),
      funcType_(funcType)
{
    CHECK(FeError::INVALID_TYPE, funcType == FunctionType::STATIC || funcType == FunctionType::DYNAMIC_LOOP)
        << "funcType: " << GetFunctionTypeNameDict().Find(funcType);
    // A non-positive unroll factor would later corrupt the unroll state machine. Validate before
    // registering on the loop stack: if a check throws after emplace_back, the destructor (which
    // pops the stack) never runs for this half-constructed object.
    for (int unrollTimes : unrollList) {
        CHECK(FeError::INVALID_VAL, unrollTimes > 0)
            << "unrollTimes[" << unrollTimes << "] in unrollList must be larger than zero!";
    }
    if (parallel_) {
        for (auto& rlf : Program::GetInstance().GetLoopStack()) {
            FE_ASSERT(FeError::INVALID_OPERATION, !rlf.get().Getparallel())
                << "The parallel attribute value does not allow nesting";
        }
    }
    Program::GetInstance().GetLoopStack().emplace_back(*this);
    GenDefaultUnrollTimes(unrollList);
    span_ = ir::Span::Current();
}
/// Remove this loop from the program loop stack, undoing the push made by the constructor.
RecordLoopFunc::~RecordLoopFunc()
{
    auto& loopStack = Program::GetInstance().GetLoopStack();
    loopStack.pop_back();
}
/**
 * Close the current loop-path function and ask the loop attribute whether the current unroll
 * factor is finished.
 *
 * @return the loop attribute's IterationEnd() verdict; when true, the path counter endCount_
 *         is reset to 0.
 */
bool RecordLoopFunc::IterationEnd()
{
    auto endResult = Program::GetInstance().EndFunction(curPathFuncName_);
    auto finishedPath = std::get<0>(endResult);
    finishedPath->ApplyLoopCallOrderGroup();
    // Pairs with the Checkpoint() taken in IterationBegin().
    Program::GetInstance().GetTensorSlotManager()->Restore();
    const bool done = GetLoopAttr()->IterationEnd(CurUnrollTimes(), finishedPath, std::get<1>(endResult));
    if (!done) {
        return false;
    }
    endCount_ = 0;
    return true;
}
/**
 * Open the DYNAMIC_LOOP function for the current unroll factor and attach its loop attribute.
 *
 * Each unroll factor k gets a sub-range of the user range with step Step() * k. Sub-ranges are
 * chained: the first starts at the loop Begin(), and each later one starts where the previous
 * one ended. While UnrollTimesSize() != 1 (presumably: more factors remain), a sub-range stops
 * at the last full-chunk boundary measured from its own start. Otherwise it runs to the
 * loop End().
 */
void RecordLoopFunc::BeginLoopFunction()
{
    auto loopFuncName = name_ + "_Unroll" + std::to_string(CurUnrollTimes());
    Program::GetInstance().BeginFunction(loopFuncName, FunctionType::DYNAMIC_LOOP);
    currentLoopFunc_ = Program::GetInstance().GetCurrentFunction();
    CHECK(FeError::IS_EXIST, currentLoopFunc_->InsertLoopIdxNameList(iterName_))
        << "Forbid duplicate name of loop idx. It names " << iterName_;
    // One iteration of the unrolled body advances the index by Step() * unrollTimes.
    auto currentStep = CurUnrollTimes() == 1 ? loopRange_->Step() : loopRange_->Step() * CurUnrollTimes();
    if (rangeOfEaceUnroll_.empty()) {
        // Chunk boundaries are measured from Begin(). The previous `End / step * step` was only
        // correct when Begin() is a multiple of the step; e.g. Begin=3, End=8, unroll 4 gave
        // [3, 8), whose chunk at 7 runs past End, and the next range overlapped it.
        auto newRangeEnd =
            (UnrollTimesSize() == 1 ?
                 loopRange_->End() :
                 loopRange_->Begin() + (loopRange_->End() - loopRange_->Begin()) / currentStep * currentStep);
        std::shared_ptr<LoopRange> newRange =
            std::make_shared<LoopRange>(loopRange_->Begin(), newRangeEnd, currentStep);
        rangeOfEaceUnroll_.push_back(newRange);
    } else {
        auto prevRange = rangeOfEaceUnroll_.back();
        auto newRangeEnd =
            (UnrollTimesSize() == 1 ?
                 loopRange_->End() :
                 prevRange->End() + (loopRange_->End() - prevRange->End()) / currentStep * currentStep);
        std::shared_ptr<LoopRange> newRange = std::make_shared<LoopRange>(prevRange->End(), newRangeEnd, currentStep);
        rangeOfEaceUnroll_.push_back(newRange);
    }
    auto range = rangeOfEaceUnroll_.back();
    range->End().AsIntermediateVariable();
    auto attr =
        std::make_shared<DynloopFunctionAttribute>(iterName_, *range, *loopRange_, submitBeforeLoop_, parallel_);
    currentLoopFunc_->SetDynloopAttribute(attr);
    currentLoopFunc_->SetSpan(span_);
}
/**
 * Close the DYNAMIC_LOOP function of the current unroll factor and clear currentLoopFunc_, so
 * the next IterationBegin() opens a fresh loop function.
 */
void RecordLoopFunc::EndLoopFunction()
{
    const auto loopFuncName = name_ + "_Unroll" + std::to_string(CurUnrollTimes());
    (void)Program::GetInstance().EndFunction(loopFuncName);
    currentLoopFunc_ = nullptr;
}
/**
 * Decide whether a manual-unroll branch for `unrollTimes` should be traced in the innermost
 * active loop.
 *
 * Every requested factor is remembered in customUnrollTimes_. The first request on a loop puts
 * it into a dry run, during which this returns false. Afterwards, factors not yet visited are
 * registered via VisitUnroll(). The branch matches when `unrollTimes` equals the current factor.
 * A factor-1 branch also matches when the current factor is > 1 and CustomUnrollTimesMatched()
 * is false.
 *
 * Must be called inside a loop region (the loop stack must be non-empty).
 */
bool RecordLoopFunc::MatchUnrollTimes(int unrollTimes)
{
    CHECK(FeError::INVALID_VAL, unrollTimes > 0) << "unrollTimes[" << unrollTimes << "] must larger than zero!";
    auto& loopStack = Program::GetInstance().GetLoopStack();
    // back() on an empty stack is undefined behavior; report misuse outside a loop instead.
    FE_ASSERT(FeError::INVALID_OPERATION, !loopStack.empty())
        << "unrollTimes[" << unrollTimes << "] must be matched inside a loop";
    auto& curRlf = loopStack.back().get();
    curRlf.customUnrollTimes_.emplace(unrollTimes);
    if (!curRlf.hasManualUnroll_) {
        curRlf.hasManualUnroll_ = true;
        curRlf.dryRun_ = true;
    }
    if (curRlf.dryRun_) {
        return false;
    }
    if (!curRlf.VisitedUnroll(unrollTimes)) {
        curRlf.VisitUnroll(unrollTimes);
    }
    FE_ASSERT(curRlf.StillHaveUnrollTimes()) << "unrollTimes_ is empty.";
    if (curRlf.CurUnrollTimes() == unrollTimes) {
        return true;
    }
    return curRlf.CurUnrollTimes() > 1 && !curRlf.CustomUnrollTimesMatched() && unrollTimes == 1;
}
/**
 * Advance the loop iterator.
 *
 * Does nothing during a dry run. When CustomUnrollTimesMatched() is true, the whole chunk is
 * consumed at once: cur_ jumps by CurUnrollTimes() and the index by Step() * CurUnrollTimes().
 * Otherwise the index moves by one Step() and cur_ by one. Then IterationNext() runs.
 */
RecordLoopFunc::Iterator RecordLoopFunc::Iterator::operator++()
{
    if (rlf_.dryRun_) {
        return *this;
    }
    if (rlf_.CustomUnrollTimesMatched()) {
        FE_ASSERT(cur_ == 0) << "The cur_ = " << cur_;
        scalar_ = scalar_ + rlf_.LoopStep() * rlf_.CurUnrollTimes();
        cur_ += rlf_.CurUnrollTimes();
    } else {
        scalar_ = scalar_ + rlf_.LoopStep();
        ++cur_;
    }
    rlf_.IterationNext();
    return *this;
}
/**
 * Loop-continuation test that drives the tracing state machine of RecordLoopFunc.
 *
 * rhs is unused; the decision comes entirely from rlf_'s state. While cur_ < CurUnrollTimes(),
 * the body keeps executing copies of the current unrolled chunk. Once cur_ reaches
 * CurUnrollTimes(), the current path is closed via IterationEnd(). If that also finishes the
 * current unroll factor, its loop function is closed and the next factor is selected; iteration
 * stops when no factor remains. Otherwise a new path starts from originalScalar_. scalar_ is
 * tagged with AsLoopBegin/AsLoopEnd to mark the first and last body copy of a chunk.
 */
bool RecordLoopFunc::Iterator::operator!=(const IteratorEnd& rhs)
{
    (void)rhs;
    // Dry run (dryRun_ is set by MatchUnrollTimes): clear it and return true so the body runs
    // again with cur_ still 0.
    if (rlf_.dryRun_) {
        rlf_.dryRun_ = false;
        FE_ASSERT(cur_ == 0) << "The cur_ = " << cur_;
        if (rlf_.IsCustomUnrollTimes(rlf_.CurUnrollTimes())) {
            scalar_.AsLoopEnd(true);
        }
        return true;
    }
    FE_ASSERT(rlf_.StillHaveUnrollTimes()) << "unrollTimes_ is empty.";
    // Still inside the current chunk: keep executing body copies.
    if (cur_ < rlf_.CurUnrollTimes()) {
        // No body copy has run for this path yet: open it.
        if (cur_ == 0) {
            scalar_.AsLoopBegin(true);
            rlf_.IterationBegin();
            if (rlf_.IsCustomUnrollTimes(rlf_.CurUnrollTimes())) {
                scalar_.AsLoopEnd(true);
            }
        }
        // Last body copy of the chunk.
        if (cur_ + 1 == rlf_.CurUnrollTimes()) {
            scalar_.AsLoopEnd(true);
        }
        return true;
    }
    // The chunk is complete; cur_ must land exactly on the unroll factor.
    FE_ASSERT(cur_ == rlf_.CurUnrollTimes())
        << " cur_ = " << cur_ << ", rlf_.CurUnrollTimes() = " << rlf_.CurUnrollTimes();
    // Close the path. If this unroll factor is finished, close its loop function and move on to
    // the next factor, stopping when none remain.
    if (rlf_.IterationEnd()) {
        rlf_.EndLoopFunction();
        rlf_.NextUnrollTimes();
        if (!rlf_.StillHaveUnrollTimes()) {
            return false;
        }
    }
    FE_ASSERT(rlf_.StillHaveUnrollTimes()) << "unrollTimes_ is empty.";
    // Start a new path (same or next unroll factor) from the original index value.
    cur_ = 0;
    scalar_ = originalScalar_;
    scalar_.AsLoopBegin(true);
    rlf_.IterationBegin();
    if (rlf_.IsCustomUnrollTimes(rlf_.CurUnrollTimes()) || cur_ + 1 == rlf_.CurUnrollTimes()) {
        scalar_.AsLoopEnd(true);
    }
    return true;
}
/// Iterator at the start of the loop; its symbolic index is named iterName_.
RecordLoopFunc::Iterator RecordLoopFunc::begin()
{
    return Iterator{*this, SymbolicScalar(iterName_)};
}
/// End sentinel of the loop; termination is decided by Iterator::operator!=, which ignores it.
RecordLoopFunc::IteratorEnd RecordLoopFunc::end()
{
    return IteratorEnd{*this, SymbolicScalar(iterName_)};
}
/**
 * Start recording a new DYNAMIC_LOOP_PATH function for the current unroll factor.
 *
 * Opens the loop function first if none is active. Then it checkpoints the tensor slot manager
 * and begins a path function named "<name_>_Unroll<k>" followed by GetLoopSuffix() of the
 * pre-increment endCount_. Finally it notifies the loop attribute.
 */
void RecordLoopFunc::IterationBegin()
{
    if (currentLoopFunc_ == nullptr) {
        BeginLoopFunction();
    }
    const auto pathIndex = endCount_++;
    curPathFuncName_ = name_ + "_Unroll" + std::to_string(CurUnrollTimes()) + GetLoopSuffix(pathIndex);
    auto& program = Program::GetInstance();
    program.GetTensorSlotManager()->Checkpoint();
    program.BeginFunction(curPathFuncName_, FunctionType::DYNAMIC_LOOP_PATH);
    program.GetCurrentFunction()->SetSpan(span_);
    GetLoopAttr()->IterationBegin();
}
/**
 * Per-step check run from Iterator::operator++. Once any custom unroll factor has been requested
 * (customUnrollTimes_ is filled by MatchUnrollTimes), factor 1 must be among them.
 */
void RecordLoopFunc::IterationNext()
{
    FE_ASSERT(customUnrollTimes_.empty() || customUnrollTimes_.count(1) > 0)
        << "Must have unroll 1 if user defined custom unroll times.";
}
/**
 * Append a branch condition (with its source location) to the current loop attribute.
 *
 * @return the value reported by DynloopFunctionAttribute::AppendCond; it is also logged with
 *         file:line.
 */
bool RecordLoopFunc::Condition(const SymbolicScalar& cond, const std::string& file, int line)
{
    const bool taken = GetLoopAttr()->AppendCond(cond, file, line);
    FE_LOGI("[%s:%d]: %s", file.c_str(), line, taken ? "true" : "false");
    return taken;
}
/**
 * Reset the pending (unrollTimes_) and visited (visited_) unroll factors to their defaults.
 *
 * Factor 1 is always present. The entries of unrollList are added too, unless the platform
 * config ONLY_MANUAL_UNROLL is enabled.
 */
void RecordLoopFunc::GenDefaultUnrollTimes(const std::set<int>& unrollList)
{
    unrollTimes_.clear();
    visited_.clear();
    const bool onlyManual = config::GetPlatformConfig("ONLY_MANUAL_UNROLL", false);
    if (!onlyManual) {
        for (int factor : unrollList) {
            unrollTimes_.emplace(factor);
            visited_.emplace(factor);
        }
    }
    unrollTimes_.emplace(1);
    visited_.emplace(1);
}
/**
 * Register a newly requested unroll factor as both visited and pending. Neither visited_ nor
 * unrollTimes_ may already contain it.
 */
void RecordLoopFunc::VisitUnroll(int unrollTimes)
{
    FE_ASSERT(FeError::IS_EXIST, visited_.count(unrollTimes) == 0)
        << "unrollTimes[" << unrollTimes << "] already exists in visited.";
    FE_ASSERT(FeError::IS_EXIST, unrollTimes_.count(unrollTimes) == 0)
        << "unrollTimes[" << unrollTimes << "] already exists..";
    visited_.emplace(unrollTimes);
    unrollTimes_.emplace(unrollTimes);
}
/// Unroll factor currently being traced: the first element of unrollTimes_ (which must not be empty).
int RecordLoopFunc::CurUnrollTimes() const
{
    FE_ASSERT(StillHaveUnrollTimes()) << "unrollTimes_ is empty.";
    return *unrollTimes_.begin();
}
/// Drop the current unroll factor (first element of unrollTimes_) so the next one becomes current.
void RecordLoopFunc::NextUnrollTimes()
{
    FE_ASSERT(StillHaveUnrollTimes()) << "unrollTimes_ is empty.";
    unrollTimes_.erase(unrollTimes_.begin());
}
std::shared_ptr<DynloopFunctionAttribute> RecordLoopFunc::GetLoopAttr()
{
return currentLoopFunc_->GetDynloopAttribute();
}
/// Begin of the user-specified loop range.
const SymbolicScalar& RecordLoopFunc::LoopBegin() const
{
    return loopRange_->Begin();
}

/// Step of the user-specified loop range.
const SymbolicScalar& RecordLoopFunc::LoopStep() const
{
    return loopRange_->Step();
}

/// End of the user-specified loop range.
const SymbolicScalar& RecordLoopFunc::LoopEnd() const
{
    return loopRange_->End();
}
/**
 * Evaluate the recorded branch condition cond_ (captured at file_:line_) against the innermost
 * active loop via RecordLoopFunc::Condition.
 *
 * Must be used inside a loop region. An empty loop stack previously made back() undefined
 * behavior; it now fails with an assertion naming the source location.
 */
RecordIfBranch::operator bool() const
{
    auto& loopStack = Program::GetInstance().GetLoopStack();
    FE_ASSERT(FeError::INVALID_OPERATION, !loopStack.empty())
        << "Branch condition at " << file_ << ":" << line_ << " must be recorded inside a loop";
    return loopStack.back().get().Condition(cond_, file_, line_);
}
}