/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "ir/transforms/passes.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "core/error.h"
#include "core/logging.h"
#include "ir/function.h"
#include "ir/program.h"
#include "ir/transforms/ir_property.h"
#include "ir/transforms/pass_context.h"
#include "ir/verifier/verifier.h"
namespace pypto {
namespace ir {
Pass::Pass() : impl_(nullptr) {}
Pass::Pass(std::shared_ptr<PassImpl> impl) : impl_(std::move(impl)) {}
Pass::~Pass() = default;
Pass::Pass(const Pass& other) = default;
Pass& Pass::operator=(const Pass& other) = default;
Pass::Pass(Pass&& other) noexcept = default;
Pass& Pass::operator=(Pass&& other) noexcept = default;
/**
 * \brief Run this pass on `program`, invoking the active PassContext hooks (if any) around it.
 *
 * \param program Input program; must be non-null.
 * \return The program produced by the wrapped implementation; guaranteed non-null.
 */
ProgramPtr Pass::operator()(const ProgramPtr& program) const
{
    // Both the implementation and the input must exist before anything runs.
    INTERNAL_CHECK(impl_) << "Pass has null implementation";
    INTERNAL_CHECK(program) << "Pass cannot run on null program";

    // Instrumentation is optional: hooks fire only when a context is currently active.
    auto* const context = PassContext::Current();
    if (context != nullptr) {
        context->RunBeforePass(*this, program);
    }

    ProgramPtr transformed = impl_->operator()(program);
    INTERNAL_CHECK(transformed) << "Pass '" << GetName() << "' returned null program";

    if (context != nullptr) {
        context->RunAfterPass(*this, transformed);
    }
    return transformed;
}
// Named alias for operator(); behaves identically, including the PassContext hooks.
ProgramPtr Pass::run(const ProgramPtr& program) const
{
    return operator()(program);
}
// Returns the implementation's name, or "NullPass" when this pass wraps no implementation.
std::string Pass::GetName() const
{
    if (impl_) {
        return impl_->GetName();
    }
    return "NullPass";
}
// IR properties that must hold before this pass runs; empty for a pass without an implementation.
IRPropertySet Pass::GetRequiredProperties() const
{
    if (impl_) {
        return impl_->GetRequiredProperties();
    }
    return {};
}
// IR properties this pass establishes; empty for a pass without an implementation.
IRPropertySet Pass::GetProducedProperties() const
{
    if (impl_) {
        return impl_->GetProducedProperties();
    }
    return {};
}
// IR properties this pass may break; empty for a pass without an implementation.
IRPropertySet Pass::GetInvalidatedProperties() const
{
    if (impl_) {
        return impl_->GetInvalidatedProperties();
    }
    return {};
}
namespace {
* \brief Pass implementation that wraps a program transform function
*/
class ProgramPassImpl : public PassImpl {
public:
ProgramPassImpl(std::function<ProgramPtr(const ProgramPtr&)> transform, std::string name, PassProperties properties)
: transform_(std::move(transform)), name_(std::move(name)), properties_(properties)
{}
ProgramPtr operator()(const ProgramPtr& program) override
{
INTERNAL_CHECK(program) << "ProgramPass cannot run on null program";
return transform_(program);
}
[[nodiscard]] std::string GetName() const override { return name_.empty() ? "ProgramPass" : name_; }
[[nodiscard]] IRPropertySet GetRequiredProperties() const override { return properties_.required; }
[[nodiscard]] IRPropertySet GetProducedProperties() const override { return properties_.produced; }
[[nodiscard]] IRPropertySet GetInvalidatedProperties() const override { return properties_.invalidated; }
private:
std::function<ProgramPtr(const ProgramPtr&)> transform_;
std::string name_;
PassProperties properties_;
};
* \brief Pass implementation that applies a function transform to each function in program
*/
class FunctionPassImpl : public PassImpl {
public:
FunctionPassImpl(
std::function<FunctionPtr(const FunctionPtr&)> transform, std::string name, PassProperties properties)
: transform_(std::move(transform)), name_(std::move(name)), properties_(properties)
{}
ProgramPtr operator()(const ProgramPtr& program) override
{
INTERNAL_CHECK(program) << "FunctionPass cannot run on null program";
std::vector<FunctionPtr> transformed_functions;
transformed_functions.reserve(program->functions_.size());
for (const auto& entry : program->functions_) {
const auto& func = entry.second;
FunctionPtr transformed_func = transform_(func);
transformed_functions.push_back(transformed_func);
}
return std::make_shared<const Program>(transformed_functions, program->name_, program->span_);
}
[[nodiscard]] std::string GetName() const override { return name_.empty() ? "FunctionPass" : name_; }
[[nodiscard]] IRPropertySet GetRequiredProperties() const override { return properties_.required; }
[[nodiscard]] IRPropertySet GetProducedProperties() const override { return properties_.produced; }
[[nodiscard]] IRPropertySet GetInvalidatedProperties() const override { return properties_.invalidated; }
private:
std::function<FunctionPtr(const FunctionPtr&)> transform_;
std::string name_;
PassProperties properties_;
};
}
namespace pass {
/**
 * \brief Build a Pass from a whole-program transform.
 *
 * \param transform  Callable invoked with the entire program; ownership moves into the pass.
 * \param name       Pass name reported by GetName().
 * \param properties IR property contract reported by the pass.
 */
Pass CreateProgramPass(
    std::function<ProgramPtr(const ProgramPtr&)> transform, const std::string& name, const PassProperties& properties)
{
    auto impl = std::make_shared<ProgramPassImpl>(std::move(transform), name, properties);
    return Pass(std::move(impl));
}
/**
 * \brief Build a Pass that applies a per-function transform to every function in a program.
 *
 * \param transform  Callable invoked once per function; ownership moves into the pass.
 * \param name       Pass name reported by GetName().
 * \param properties IR property contract reported by the pass.
 */
Pass CreateFunctionPass(
    std::function<FunctionPtr(const FunctionPtr&)> transform, const std::string& name, const PassProperties& properties)
{
    auto impl = std::make_shared<FunctionPassImpl>(std::move(transform), name, properties);
    return Pass(std::move(impl));
}
/**
 * \brief Build a pass named "IRVerifier" that runs the default IR verifier on a program.
 *
 * Diagnostics, if any, are logged via IR_LOGI as a report; the program is always returned
 * unchanged, so verification findings do not stop the pipeline.
 *
 * \param disabled_rules Names of verifier rules to disable on every run.
 */
Pass RunVerifier(const std::vector<std::string>& disabled_rules)
{
    // Snapshot the list once; copies of the returned pass share this immutable vector.
    auto rules_to_skip = std::make_shared<const std::vector<std::string>>(disabled_rules);

    auto verify_program = [rules_to_skip](const ProgramPtr& program) -> ProgramPtr {
        IRVerifier verifier = IRVerifier::CreateDefault();
        for (const auto& skipped_rule : *rules_to_skip) {
            verifier.DisableRule(skipped_rule);
        }

        auto findings = verifier.Verify(program);
        if (findings.empty()) {
            return program;
        }

        std::string report_text = IRVerifier::GenerateReport(findings);
        IR_LOGI() << "IR Verification Report:\n" << report_text;
        return program;
    };

    return CreateProgramPass(std::move(verify_program), "IRVerifier");
}
}
namespace pass {
// Builds a ProgramPass called `name` whose transform returns its input program unchanged.
static Pass MakeIdentityPass(const std::string& name)
{
    auto passthrough = [](const ProgramPtr& input) { return input; };
    return CreateProgramPass(passthrough, name);
}
// The factories below are currently placeholders. Each one returns an identity pass, i.e. one
// that leaves the program unchanged, registered under the pass's public name.
Pass InitMemRef()
{
    return MakeIdentityPass("InitMemRef");
}

Pass BasicMemoryReuse()
{
    return MakeIdentityPass("BasicMemoryReuse");
}

Pass AllocateMemoryAddr()
{
    return MakeIdentityPass("AllocateMemoryAddr");
}

Pass OutlineIncoreScopes()
{
    return MakeIdentityPass("OutlineIncoreScopes");
}

Pass ConvertTensorToBlockOps()
{
    return MakeIdentityPass("ConvertTensorToBlockOps");
}

Pass FlattenCallExpr()
{
    return MakeIdentityPass("FlattenCallExpr");
}

Pass NormalizeStmtStructure()
{
    return MakeIdentityPass("NormalizeStmtStructure");
}

Pass FlattenSingleStmt()
{
    return MakeIdentityPass("FlattenSingleStmt");
}
}
// Default constructor: every member is default-initialized as declared in the header
// (presumably leaving the pipeline with no passes — confirm against the header).
PassPipeline::PassPipeline() = default;
// Appends `pass` to the pipeline; it runs after every previously added pass.
void PassPipeline::AddPass(Pass pass)
{
    passes_.emplace_back(std::move(pass));
}
// Threads `program` through every pass in insertion order; each pass consumes the previous
// pass's output. With no passes registered, the input is returned as-is.
ProgramPtr PassPipeline::Run(const ProgramPtr& program) const
{
    ProgramPtr result = program;
    for (const auto& stage : passes_) {
        result = stage(result);
    }
    return result;
}
// Returns one name per registered pass, in the order the passes will run.
std::vector<std::string> PassPipeline::GetPassNames() const
{
    std::vector<std::string> pass_names(passes_.size());
    std::transform(passes_.begin(), passes_.end(), pass_names.begin(),
                   [](const auto& stage) { return stage.GetName(); });
    return pass_names;
}
}
}