* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include <any>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "ir/builder.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/pre_def.h"
#include "interface/operation/operation.h"
using namespace pypto;
namespace npu::tile_fwk {
class IRContext {
public:
IRContext(const IRContext&) = delete;
IRContext& operator=(const IRContext&) = delete;
IRContext(IRContext&&) = delete;
IRContext& operator=(IRContext&&) = delete;
ir::VarPtr MakeVar(std::string name, ir::TypePtr type, ir::Span span)
{
auto var_name = GetVarName(name);
type_map_[var_name] = type;
return std::make_shared<ir::Var>(var_name, type, span);
}
ir::VarPtr MakeTempVar(ir::TypePtr type, ir::Span span)
{
auto var_name = GetVarName();
return MakeVar(var_name, type, span);
}
ir::IterArgPtr MakeIterArg(std::string name, ir::TypePtr type, ir::ExprPtr initVal, ir::Span span)
{
auto var_name = GetVarName(name);
type_map_[var_name] = type;
auto var = std::make_shared<ir::Var>(var_name, type, span);
return std::make_shared<ir::IterArg>(var, initVal);
}
ir::VarPtr MakeToken() { return MakeTempVar(ir::GetTokenType(), ir::Span::Unknown()); }
ir::TypePtr GetType(ir::VarPtr var) { return type_map_[var->name_]; }
void SetType(ir::VarPtr var, ir::TypePtr type) { type_map_[var->name_] = type; }
std::string GetOriginName(ir::VarPtr var) { return all_vars_[var->name_]; }
std::string GetVarName(const std::string& name = "")
{
auto var_name = name;
if (var_name.empty()) {
auto idx = temp_counter_++;
var_name = "$" + std::to_string(idx);
} else {
while (all_vars_.count(var_name)) {
auto idx = var_counter_[var_name]++;
var_name = name + "." + std::to_string(idx);
}
}
all_vars_[var_name] = name;
return var_name;
}
void Reset()
{
temp_counter_ = 0;
type_map_.clear();
var_counter_.clear();
all_vars_.clear();
}
static IRContext& Get();
private:
IRContext() = default;
int64_t temp_counter_{0};
std::map<std::string, ir::TypePtr> type_map_;
std::map<std::string, int64_t> var_counter_;
std::map<std::string, std::string> all_vars_;
};
/**
 * tile_fwk-side IR builder derived from ir::IRBuilder.
 *
 * This header only declares the interface; no member function body appears here. The builder
 * refers to an IRContext through a reference member (irContext_) rather than owning one, and it
 * can be neither copied nor moved.
 */
class IRBuilder : public ir::IRBuilder {
public:
    IRBuilder();

    // Copy and move are both disabled.
    IRBuilder(const IRBuilder&) = delete;
    IRBuilder& operator=(const IRBuilder&) = delete;
    IRBuilder(IRBuilder&&) = delete;
    IRBuilder& operator=(IRBuilder&&) = delete;

    // ---- Raw tensors --------------------------------------------------------------------------
    // Shape given either as a Shape or as a list of SymbolicScalar dimensions.
    std::shared_ptr<RawTensor> CreateRawTensor(
        DataType t,
        Shape shape,
        TileOpFormat format = TileOpFormat::TILEOP_ND,
        std::string name = "");
    std::shared_ptr<RawTensor> CreateRawTensor(
        DataType t,
        std::vector<SymbolicScalar> shape,
        TileOpFormat format = TileOpFormat::TILEOP_ND,
        std::string name = "");

    // ---- Logical tensors (no explicit Function) ------------------------------------------------
    LogicalTensorPtr CreateTensorVar(
        DataType t,
        Shape shape,
        TileOpFormat format = TileOpFormat::TILEOP_ND,
        std::string name = "");
    LogicalTensorPtr CreateTensorVar(
        DataType t,
        Shape shape,
        std::vector<SymbolicScalar> validShape,
        TileOpFormat format = TileOpFormat::TILEOP_ND,
        std::string name = "");
    LogicalTensorPtr CreateTensorVar(
        std::shared_ptr<RawTensor> rawTensor,
        Offset offset,
        Shape shape,
        std::vector<SymbolicScalar> validShape = {});

    // ---- Logical tensors (explicit Function) ---------------------------------------------------
    // NOTE(review): these presumably create the tensor inside `f`; confirm against the definitions.
    LogicalTensorPtr CreateTensorVar(
        Function& f,
        DataType t,
        Shape shape,
        TileOpFormat format = TileOpFormat::TILEOP_ND,
        std::string name = "");
    LogicalTensorPtr CreateTensorVar(
        Function& f,
        DataType t,
        Shape shape,
        std::vector<SymbolicScalar> validShape,
        TileOpFormat format = TileOpFormat::TILEOP_ND,
        std::string name = "");
    LogicalTensorPtr CreateTensorVar(
        Function& f,
        std::shared_ptr<RawTensor> rawTensor,
        Offset offset,
        Shape shape,
        std::vector<SymbolicScalar> validShape = {});

    // ---- Tensor operations ---------------------------------------------------------------------
    // Overload on Function& returning the project-level Operation.
    Operation& CreateTensorOpStmt(
        Function& f,
        const Opcode opCode,
        const LogicalTensors& iOperands,
        const LogicalTensors& oOperands,
        ir::Span span = ir::Span::Unknown());
    // Overload building an ir::TensorOpStmt from IR values, tokens and named attributes.
    ir::TensorOpStmtPtr CreateTensorOpStmt(
        std::vector<ir::VarPtr> result,
        ir::VarPtr result_token,
        std::string opcode,
        std::vector<ir::ExprPtr> args,
        std::vector<ir::VarPtr> tokens,
        std::vector<std::pair<std::string, std::any>> attrs,
        ir::Span span);

    std::shared_ptr<Function> CreateFunction(
        std::string name,
        LogicalTensors params,
        ir::StmtPtr body,
        ir::Span span = ir::Span::Unknown());

    // ---- Scalars and variables -----------------------------------------------------------------
    SymbolicScalar CreateConstInt(int64_t value);
    SymbolicScalar CreateScalarVar(std::string sym);
    ir::VarPtr CreateVarLike(std::string name, ir::ExprPtr value);

    // ---- IR statements -------------------------------------------------------------------------
    ir::AssignStmtPtr CreateAssignStmt(ir::VarPtr var, ir::ExprPtr value, ir::Span span);
    ir::SeqStmtsPtr CreateSeqStmts(std::vector<ir::StmtPtr> stmts, ir::Span span);
    ir::IfStmtPtr CreateIfStmt(
        ir::ExprPtr cond,
        ir::StmtPtr thenBody,
        std::optional<ir::StmtPtr> elseBody,
        std::vector<ir::VarPtr> returnVars,
        ir::Span span);
    ir::YieldStmtPtr CreateYieldStmt(std::vector<ir::ExprPtr> values, ir::Span span);
    ir::ReturnStmtPtr CreateReturnStmt(std::vector<ir::ExprPtr> values, ir::Span span);
    ir::ForStmtPtr CreateForStmt(
        ir::VarPtr loopVar,
        ir::ExprPtr start,
        ir::ExprPtr stop,
        ir::ExprPtr step,
        std::vector<ir::IterArgPtr> iterArgs,
        ir::StmtPtr body,
        std::vector<ir::VarPtr> returnVars,
        ir::Span span);
    ir::IterArgPtr CreateIterArg(std::string name, ir::TypePtr type, ir::ExprPtr initValue, ir::Span span);
    ir::IterArgPtr CreateIterArg(ir::VarPtr var, ir::ExprPtr initValue);
    ir::WhileStmtPtr CreateWhileStmt(
        ir::ExprPtr cond,
        std::vector<ir::IterArgPtr> iterArgs,
        ir::StmtPtr body,
        std::vector<ir::VarPtr> returnVars,
        ir::Span span);
    ir::BreakStmtPtr CreateBreakStmt(std::vector<ir::ExprPtr> values, ir::Span span);
    ir::ContinueStmtPtr CreateContinueStmt(std::vector<ir::ExprPtr> values, ir::Span span);

    // ---- Functions and programs ----------------------------------------------------------------
    ir::FunctionPtr CreateFunction(
        std::string name,
        std::vector<ir::VarPtr> params,
        std::vector<ir::TypePtr> returnTypes,
        ir::StmtPtr body,
        ir::Span span);
    ir::ProgramPtr CreateProgram(std::vector<ir::FunctionPtr> functions, std::string name, ir::Span span);

    // ---- Miscellaneous -------------------------------------------------------------------------
    ir::VarPtr CreateTokenVar(ir::Span span);
    void EmitTensorStmts();
    ir::ExprPtr None();

private:
    IRContext& irContext_;  // non-owning; the referenced context must outlive this builder
};
}