/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include <any>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <ios>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
#include "core/any_cast.h"
#include "core/dtype.h"
#include "core/error.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/kind_traits.h"
#include "ir/memref.h"
#include "ir/pipe.h"
#include "ir/program.h"
#include "ir/scalar_expr.h"
#include "ir/stmt.h"
#include "ir/transforms/base/visitor.h"
#include "ir/transforms/printer.h"
#include "ir/type.h"
namespace pypto {
namespace ir {
namespace {
/**
 * \brief Return the dtype name used in the printed Python DSL.
 *
 * BF16 is spelled "BFLOAT16"; all other dtypes use DTypeToString().
 */
std::string BlockDslDTypeToString(const DataType& dtype)
{
    if (dtype != DataType::BF16) {
        return DTypeToString(dtype);
    }
    return "BFLOAT16";
}
/**
 * \brief Return the Python-side member name for a TileLayout value.
 *
 * An unrecognised value trips INTERNAL_CHECK.
 */
const char* TileLayoutToPythonName(TileLayout layout)
{
    if (layout == TileLayout::none_box) {
        return "none_box";
    }
    if (layout == TileLayout::row_major) {
        return "row_major";
    }
    if (layout == TileLayout::col_major) {
        return "col_major";
    }
    INTERNAL_CHECK(false) << "Unknown TileLayout in PrintTileView";
    return "";
}
/**
 * \brief Return the Python-side member name for a TilePad value.
 *
 * An unrecognised value trips INTERNAL_CHECK.
 */
const char* TilePadToPythonName(TilePad pad)
{
    if (pad == TilePad::null) {
        return "null";
    }
    if (pad == TilePad::zero) {
        return "zero";
    }
    if (pad == TilePad::max) {
        return "max";
    }
    if (pad == TilePad::min) {
        return "min";
    }
    INTERNAL_CHECK(false) << "Unknown TilePad in PrintTileView";
    return "";
}
/**
 * \brief Return the Python-side member name for a CompactMode value.
 *
 * An unrecognised value trips INTERNAL_CHECK.
 */
const char* CompactModeToPythonName(CompactMode compact)
{
    if (compact == CompactMode::null) {
        return "null";
    }
    if (compact == CompactMode::normal) {
        return "normal";
    }
    if (compact == CompactMode::row_plus_one) {
        return "row_plus_one";
    }
    INTERNAL_CHECK(false) << "Unknown CompactMode in PrintTileView";
    return "";
}
void PrintPythonIterArgTuple(std::ostringstream& stream, const std::vector<IterArgPtr>& iter_args)
{
stream << "(";
bool need_comma = false;
for (const auto& iter_arg : iter_args) {
if (need_comma) {
stream << ", ";
}
stream << iter_arg->iterVar_->name_;
need_comma = true;
}
if (iter_args.size() == 1) {
stream << ",";
}
stream << ")";
}
template <typename VisitExprFn>
void PrintPythonInitValuesClause(
std::ostringstream& stream, const std::vector<IterArgPtr>& iter_args, const VisitExprFn& visit_expr)
{
stream << "init_values=(";
bool need_comma = false;
for (const auto& iter_arg : iter_args) {
if (need_comma) {
stream << ", ";
}
visit_expr(iter_arg->initValue_);
need_comma = true;
}
if (iter_args.size() == 1) {
stream << ",";
}
stream << ")";
}
/**
 * \brief Write the "for ... in <prefix>.range(...):" header line of a ForStmt.
 *
 * When the loop carries iteration arguments, their names follow the loop
 * variable as a tuple and an init_values clause is appended to the range call.
 *
 * \param stream Destination stream for the punctuation
 * \param prefix Module alias written before "range"
 * \param op Loop statement being printed
 * \param visit_expr Callable invoked to emit each expression
 */
template <typename VisitExprFn>
void PrintPythonRangeHeader(
    std::ostringstream& stream, const std::string& prefix, const ForStmtPtr& op, const VisitExprFn& visit_expr)
{
    const bool has_iter_args = !op->iterArgs_.empty();
    // Emits ", " followed by the expression.
    const auto emit_after_comma = [&](const auto& expr) {
        stream << ", ";
        visit_expr(expr);
    };

    stream << "for " << op->loopVar_->name_;
    if (has_iter_args) {
        stream << ", ";
        PrintPythonIterArgTuple(stream, op->iterArgs_);
    }

    stream << " in " << prefix << ".range(";
    visit_expr(op->start_);
    emit_after_comma(op->stop_);
    emit_after_comma(op->step_);
    if (has_iter_args) {
        stream << ", ";
        PrintPythonInitValuesClause(stream, op->iterArgs_, visit_expr);
    }
    stream << "):\n";
}
/**
 * \brief Write the "for (...) in <prefix>.while_(init_values=(...)):" header line.
 *
 * \param stream Destination stream for the punctuation
 * \param prefix Module alias written before "while_"
 * \param op While statement whose iteration arguments are printed
 * \param visit_expr Callable invoked to emit each initial-value expression
 */
template <typename VisitExprFn>
void PrintPythonWhileHeader(
    std::ostringstream& stream, const std::string& prefix, const WhileStmtPtr& op, const VisitExprFn& visit_expr)
{
    const auto& iter_args = op->iterArgs_;

    stream << "for ";
    PrintPythonIterArgTuple(stream, iter_args);

    stream << " in " << prefix << ".while_(";
    PrintPythonInitValuesClause(stream, iter_args, visit_expr);
    stream << "):\n";
}
}
/**
 * \brief Python-style IR printer
 *
 * Prints IR nodes in Python syntax with type annotations and SSA-style control flow.
 * This is the recommended printer for new code that outputs valid Python syntax.
 *
 * Key features:
 * - Type annotations (e.g., x: pl.INT64, a: pl.Tensor[[4, 8], pl.FP32])
 * - SSA-style if/for with pl.yield_() and pl.range()
 * - Op attributes as keyword arguments
 * - Program headers with # pypto_block.program: name
 */
class IRPythonPrinter : public IRVisitor {
using IRVisitor::VisitExpr_;
using IRVisitor::VisitStmt_;
public:
explicit IRPythonPrinter(std::string prefix = "pl") : prefix_(std::move(prefix)) {}
~IRPythonPrinter() override = default;
* \brief Print an IR node to a string in Python IR syntax
*
* \param node IR node to print (can be Expr, Stmt, Function, or Program)
* \return Python-style string representation
*/
std::string Print(const IRNodePtr& node);
std::string Print(const TypePtr& type);
protected:
PYPTO_IR_PRINTER_COMMON_VISITOR_OVERRIDES();
void VisitStmt_(const SectionStmtPtr& op) override;
private:
std::ostringstream stream_;
int indent_level_ = 0;
std::string prefix_;
ProgramPtr current_program_ = nullptr;
std::string GetIndent() const;
void IncreaseIndent();
void DecreaseIndent();
void PrintStmtBlock(const StmtPtr& stmt);
void VisitStmtBody(const StmtPtr& body, const std::vector<VarPtr>& return_vars = {});
void PrintYieldAssignmentVars(const std::vector<VarPtr>& return_vars);
void PrintYieldAssignment(const YieldStmtPtr& yield_stmt, const std::vector<VarPtr>& return_vars);
void VisitSeqStmtBody(const SeqStmtsPtr& seq_stmts, const std::vector<VarPtr>& return_vars);
void PrintReturnFromYield(const YieldStmtPtr& yield_stmt);
void PrintFunctionDecorator(const FunctionPtr& func);
void PrintFunctionParam(const VarPtr& var);
void PrintFunctionSignature(const FunctionPtr& func);
void PrintFunctionReturnType(const FunctionPtr& func);
void PrintFunctionBody(const FunctionPtr& func);
bool TryPrintProgramCall(const CallPtr& op);
std::string FormatCallOpName(const std::string& op_name) const;
void PrintCallArgs(const CallPtr& op);
void PrintCallArg(const CallPtr& op, size_t index);
void PrintCallKwargs(const CallPtr& op, bool need_comma);
void PrintKwargValue(const std::string& key, const std::any& value);
void PrintBinaryOp(const BinaryExprPtr& op, const char* op_symbol);
void PrintFunctionBinaryOp(const BinaryExprPtr& op, const char* func_name);
void PrintChild(const ExprPtr& parent, const ExprPtr& child, bool is_left);
void PrintPrefixUnaryOp(const UnaryExprPtr& op, const char* op_symbol, Precedence operand_precedence);
void PrintShapedTypePrefix(
std::ostringstream& oss, const char* type_name, const std::vector<ExprPtr>& shape, DataType dtype);
void PrintShapeDims(std::ostringstream& oss, const std::vector<ExprPtr>& shape);
void PrintExprList(std::ostringstream& oss, const std::vector<ExprPtr>& exprs);
std::string PrintMemRef(const MemRef& memref);
std::string PrintTileView(const TileView& tile_view);
std::string PrintHardwareInfo(const HardwareInfo& hw_info);
std::string PrintTensorView(const TensorView& tensor_view);
};
/**
 * \brief Format a double as Python source text that evaluates back to the same value.
 *
 * - Integral values are written in fixed notation with one decimal ("3.0").
 * - Other finite values use the fewest significant digits that parse back to
 *   exactly the same double, so no precision is lost and tiny values are not
 *   collapsed to "0.0".
 * - NaN and infinities are written as float('nan'), float('inf') and
 *   float('-inf'), because the bare words "nan" and "inf" are not Python literals.
 *
 * \param value Value to format
 * \return Python expression text for the value
 */
std::string FormatFloatLiteral(double value)
{
    if (std::isnan(value)) {
        return "float('nan')";
    }
    if (std::isinf(value)) {
        return value > 0 ? "float('inf')" : "float('-inf')";
    }

    std::ostringstream oss;
    // Exact integrality test: a tolerance here would misprint values such as 1e-20 as "0.0".
    if (value == std::floor(value)) {
        oss << std::fixed << std::setprecision(1) << value;
        return oss.str();
    }

    // A non-integral value always prints with a '.' or an exponent, so the text is a float literal.
    // max_digits10 significant digits are always enough to round-trip a double.
    constexpr int kMaxDigits = std::numeric_limits<double>::max_digits10;
    std::string text;
    for (int digits = 1; digits <= kMaxDigits; ++digits) {
        oss.str("");
        oss << std::setprecision(digits) << value;
        text = oss.str();
        if (std::strtod(text.c_str(), nullptr) == value) {
            break;
        }
    }
    return text;
}
/**
 * \brief Print an IR node and return the generated text.
 *
 * The output buffer and indentation are reset first, so the result contains
 * only the text for this node.
 */
std::string IRPythonPrinter::Print(const IRNodePtr& node)
{
    // Start from a clean slate: no indentation, no stream error flags, empty buffer.
    indent_level_ = 0;
    stream_.clear();
    stream_.str("");

    PrintIRNodeWithVisitor(*this, stream_, node);
    return stream_.str();
}
/**
 * \brief Write the opening of a shaped type annotation: "<prefix>.<type_name>[[dims], <prefix>.<DTYPE>".
 *
 * The closing "]" is left to the caller so that further fields can be appended.
 */
void IRPythonPrinter::PrintShapedTypePrefix(
    std::ostringstream& oss, const char* type_name, const std::vector<ExprPtr>& shape, DataType dtype)
{
    // Shape list.
    oss << prefix_ << "." << type_name << "[[";
    PrintShapeDims(oss, shape);
    oss << "], ";

    // Element dtype.
    oss << prefix_ << "." << BlockDslDTypeToString(dtype);
}
std::string IRPythonPrinter::Print(const TypePtr& type)
{
if (auto scalar_type = As<ScalarType>(type)) {
return prefix_ + ".Scalar[" + prefix_ + "." + BlockDslDTypeToString(scalar_type->dtype_) + "]";
}
if (auto ptr_type = As<PtrType>(type)) {
return prefix_ + ".Ptr[" + prefix_ + "." + BlockDslDTypeToString(ptr_type->dtype_) + "]";
}
if (auto tensor_type = As<TensorType>(type)) {
std::ostringstream oss;
PrintShapedTypePrefix(oss, "Tensor", tensor_type->shape_, tensor_type->dtype_);
if (tensor_type->tensor_view_.has_value()) {
oss << ", tensor_view=" << PrintTensorView(tensor_type->tensor_view_.value());
}
if (tensor_type->memref_.has_value()) {
oss << ", " << PrintMemRef(*tensor_type->memref_.value());
}
oss << "]";
return oss.str();
}
if (auto tile_type = As<TileType>(type)) {
std::ostringstream oss;
PrintShapedTypePrefix(oss, "Tile", tile_type->shape_, tile_type->dtype_);
if (tile_type->tileView_.has_value()) {
oss << ", tile_view=" << PrintTileView(tile_type->tileView_.value());
}
if (tile_type->hardwareInfo_.has_value()) {
oss << ", hardware_info=" << PrintHardwareInfo(tile_type->hardwareInfo_.value());
}
if (tile_type->memref_.has_value()) {
oss << ", " << PrintMemRef(*tile_type->memref_.value());
}
oss << "]";
return oss.str();
}
if (auto tuple_type = As<TupleType>(type)) {
std::ostringstream oss;
oss << prefix_ << ".Tuple([";
for (size_t i = 0; i < tuple_type->types_.size(); ++i) {
if (i > 0)
oss << ", ";
oss << Print(tuple_type->types_[i]);
}
oss << "])";
return oss.str();
}
if (auto memref_type = As<MemRefType>(type)) {
return prefix_ + ".MemRefType";
}
return prefix_ + ".UnknownType";
}
/** \brief Return the whitespace for the current nesting depth (four spaces per level). */
std::string IRPythonPrinter::GetIndent() const
{
    const auto width = static_cast<size_t>(indent_level_ * 4);
    return std::string(width, ' ');
}

/** \brief Enter one more level of nesting. */
void IRPythonPrinter::IncreaseIndent()
{
    ++indent_level_;
}

/** \brief Leave one level of nesting; the depth never drops below zero. */
void IRPythonPrinter::DecreaseIndent()
{
    if (indent_level_ <= 0) {
        return;
    }
    --indent_level_;
}
/** \brief A variable is printed as its name. */
void IRPythonPrinter::VisitExpr_(const VarPtr& op)
{
    stream_ << op->name_;
}

/** \brief A memref is printed as its name. */
void IRPythonPrinter::VisitExpr_(const MemRefPtr& op)
{
    stream_ << op->name_;
}
/**
 * \brief Print an integer constant.
 *
 * INT64 and INDEX constants are written as bare numbers; every other dtype
 * is wrapped as "<prefix>.const(value, <prefix>.<DTYPE>)".
 */
void IRPythonPrinter::VisitExpr_(const ConstIntPtr& op)
{
    const auto dtype = op->dtype();
    const bool is_bare_literal = (dtype == DataType::INT64) || (dtype == DataType::INDEX);
    if (is_bare_literal) {
        stream_ << op->value_;
        return;
    }
    stream_ << prefix_ << ".const(" << op->value_ << ", ";
    stream_ << prefix_ << "." << BlockDslDTypeToString(dtype) << ")";
}
/**
 * \brief Print a floating-point constant.
 *
 * FP32 constants are written as bare literals; every other dtype is wrapped
 * as "<prefix>.const(value, <prefix>.<DTYPE>)".
 */
void IRPythonPrinter::VisitExpr_(const ConstFloatPtr& op)
{
    const std::string literal = FormatFloatLiteral(op->value_);
    if (op->dtype() == DataType::FP32) {
        stream_ << literal;
        return;
    }
    stream_ << prefix_ << ".const(" << literal << ", ";
    stream_ << prefix_ << "." << BlockDslDTypeToString(op->dtype()) << ")";
}
/** \brief Print a boolean constant using Python's capitalised spelling. */
void IRPythonPrinter::VisitExpr_(const ConstBoolPtr& op)
{
    if (op->value_) {
        stream_ << "True";
    } else {
        stream_ << "False";
    }
}
/**
 * \brief Print a call to a function of the program being printed as "self.<name>(args)".
 *
 * Only positional arguments are written here.
 *
 * \return true if the call was printed; false if no program is active or the
 *         program has no function with this name (nothing is written then)
 */
bool IRPythonPrinter::TryPrintProgramCall(const CallPtr& op)
{
    if (!current_program_) {
        return false;
    }
    if (!current_program_->GetFunction(op->name_)) {
        return false;
    }

    stream_ << "self." << op->name_ << "(";
    bool need_separator = false;
    for (const auto& arg : op->args_) {
        if (need_separator) {
            stream_ << ", ";
        }
        VisitExpr(arg);
        need_separator = true;
    }
    stream_ << ")";
    return true;
}
/**
 * \brief Turn an op name into the callee text used in the printed call.
 *
 * Names without a '.' are returned unchanged. Dotted names are cut at the
 * first occurrence of "_scalar" (if any) and prefixed with "<prefix>.".
 */
std::string IRPythonPrinter::FormatCallOpName(const std::string& op_name) const
{
    const bool is_dotted = op_name.find('.') != std::string::npos;
    if (!is_dotted) {
        return op_name;
    }
    // find() yields npos when "_scalar" is absent, and substr(0, npos) keeps the whole name.
    const size_t cut = op_name.find("_scalar");
    return prefix_ + "." + op_name.substr(0, cut);
}
/**
 * \brief Print one positional argument of a call.
 *
 * The first argument of "block.alloc", when it is an integer constant, is
 * printed symbolically as "<prefix>.MemorySpace.<name>". All other arguments
 * are printed as ordinary expressions.
 */
void IRPythonPrinter::PrintCallArg(const CallPtr& op, size_t index)
{
    const bool is_alloc_space_slot = (index == 0) && (op->name_ == "block.alloc");
    if (is_alloc_space_slot) {
        if (auto space_const = std::dynamic_pointer_cast<const ConstInt>(op->args_[index])) {
            const int raw_space = static_cast<int>(space_const->value_);
            const auto space = static_cast<MemorySpace>(raw_space);
            stream_ << prefix_ << ".MemorySpace." << MemorySpaceToString(space);
            return;
        }
    }
    VisitExpr(op->args_[index]);
}
/** \brief Print all positional arguments of a call, separated by ", ". */
void IRPythonPrinter::PrintCallArgs(const CallPtr& op)
{
    const size_t arg_count = op->args_.size();
    for (size_t pos = 0; pos != arg_count; ++pos) {
        if (pos != 0) {
            stream_ << ", ";
        }
        PrintCallArg(op, pos);
    }
}
void IRPythonPrinter::PrintKwargValue(const std::string& key, const std::any& value)
{
if (value.type() == typeid(int)) {
int int_val = AnyCast<int>(value, "printing kwarg: " + key);
if (key == "set_pipe" || key == "wait_pipe") {
stream_ << prefix_ << ".PipeType." << PipeTypeToString(static_cast<PipeType>(int_val));
} else {
stream_ << int_val;
}
} else if (value.type() == typeid(bool)) {
stream_ << (AnyCast<bool>(value, "printing kwarg: " + key) ? "True" : "False");
} else if (value.type() == typeid(std::string)) {
stream_ << "'" << AnyCast<std::string>(value, "printing kwarg: " + key) << "'";
} else if (value.type() == typeid(double)) {
stream_ << FormatFloatLiteral(AnyCast<double>(value, "printing kwarg: " + key));
} else if (value.type() == typeid(float)) {
stream_ << FormatFloatLiteral(static_cast<double>(AnyCast<float>(value, "printing kwarg: " + key)));
} else if (value.type() == typeid(DataType)) {
stream_ << prefix_ << "." << BlockDslDTypeToString(AnyCast<DataType>(value, "printing kwarg: " + key));
} else if (value.type() == typeid(MemorySpace)) {
stream_ << prefix_ << ".MemorySpace."
<< MemorySpaceToString(AnyCast<MemorySpace>(value, "printing kwarg: " + key));
} else if (value.type() == typeid(std::vector<int>)) {
const auto& vec = AnyCast<std::vector<int>>(value, "printing kwarg: " + key);
stream_ << "[";
for (size_t i = 0; i < vec.size(); ++i) {
if (i > 0)
stream_ << ", ";
stream_ << vec[i];
}
stream_ << "]";
} else {
CHECK(false) << "Invalid kwarg type for key: " << key
<< ", expected int, bool, std::string, double, float, DataType, MemorySpace, "
"or std::vector<int>, but got "
<< DemangleTypeName(value.type().name());
}
}
/**
 * \brief Print the keyword arguments of a call as "key=value" pairs.
 *
 * \param op Call whose kwargs_ are printed, in container order
 * \param need_comma Whether a ", " separator is required before the first pair
 */
void IRPythonPrinter::PrintCallKwargs(const CallPtr& op, bool need_comma)
{
    for (const auto& [kwarg_name, kwarg_value] : op->kwargs_) {
        if (need_comma) {
            stream_ << ", ";
        }
        stream_ << kwarg_name << "=";
        PrintKwargValue(kwarg_name, kwarg_value);
        // Every pair after the first needs a separator.
        need_comma = true;
    }
}
/**
 * \brief Print a call expression.
 *
 * Calls to functions of the current program are handled by
 * TryPrintProgramCall(); everything else is printed as
 * "<callee>(args..., key=value...)".
 */
void IRPythonPrinter::VisitExpr_(const CallPtr& op)
{
    INTERNAL_CHECK(!op->name_.empty()) << "Call has empty name";

    const bool printed_as_program_call = TryPrintProgramCall(op);
    if (printed_as_program_call) {
        return;
    }

    const bool has_positional_args = !op->args_.empty();
    stream_ << FormatCallOpName(op->name_) << "(";
    PrintCallArgs(op);
    PrintCallKwargs(op, has_positional_args);
    stream_ << ")";
}
/** \brief Print a tuple construction as a bracketed, comma-separated list. */
void IRPythonPrinter::VisitExpr_(const MakeTuplePtr& op)
{
    stream_ << "[";
    bool need_separator = false;
    for (const auto& element : op->elements_) {
        if (need_separator) {
            stream_ << ", ";
        }
        VisitExpr(element);
        need_separator = true;
    }
    stream_ << "]";
}
/** \brief Print a subscript expression as "value[slice]". */
void IRPythonPrinter::VisitExpr_(const GetItemExprPtr& op)
{
    // Subscripted value first, then the bracketed index.
    VisitExpr(op->value_);

    stream_ << "[";
    VisitExpr(op->slice_);
    stream_ << "]";
}
/**
 * \brief Print an operand of a parent expression via PrintChildExprWithParens().
 *
 * \param parent Enclosing expression
 * \param child Operand to print
 * \param is_left Whether the operand is on the left-hand side of the parent
 */
void IRPythonPrinter::PrintChild(const ExprPtr& parent, const ExprPtr& child, bool is_left)
{
    auto& self = *this;
    PrintChildExprWithParens(self, stream_, parent, child, is_left);
}
/**
 * \brief Print a binary expression in infix form: "left <symbol> right".
 *
 * Both operands go through PrintChild().
 */
void IRPythonPrinter::PrintBinaryOp(const BinaryExprPtr& op, const char* op_symbol)
{
    constexpr bool kLeftSide = true;
    constexpr bool kRightSide = false;

    PrintChild(op, op->left_, kLeftSide);
    stream_ << " " << op_symbol << " ";
    PrintChild(op, op->right_, kRightSide);
}
/** \brief Print a binary expression in call form: "func_name(left, right)". */
void IRPythonPrinter::PrintFunctionBinaryOp(const BinaryExprPtr& op, const char* func_name)
{
    stream_ << func_name << "(";

    // Operands sit inside the call's own parentheses and are printed directly.
    VisitExpr(op->left_);
    stream_ << ", ";
    VisitExpr(op->right_);

    stream_ << ")";
}
void IRPythonPrinter::PrintPrefixUnaryOp(const UnaryExprPtr& op, const char* op_symbol, Precedence operand_precedence)
{
stream_ << op_symbol;
Precedence operand_prec = GetPrecedence(op->operand_);
if (operand_prec < operand_precedence) {
stream_ << "(";
VisitExpr(op->operand_);
stream_ << ")";
} else {
VisitExpr(op->operand_);
}
}
void IRPythonPrinter::VisitExpr_(const AddPtr& op) { PrintBinaryOp(op, "+"); }
void IRPythonPrinter::VisitExpr_(const SubPtr& op) { PrintBinaryOp(op, "-"); }
void IRPythonPrinter::VisitExpr_(const MulPtr& op) { PrintBinaryOp(op, "*"); }
void IRPythonPrinter::VisitExpr_(const FloorDivPtr& op) { PrintBinaryOp(op, "//"); }
void IRPythonPrinter::VisitExpr_(const FloorModPtr& op) { PrintBinaryOp(op, "%"); }
void IRPythonPrinter::VisitExpr_(const FloatDivPtr& op) { PrintBinaryOp(op, "/"); }
void IRPythonPrinter::VisitExpr_(const PowPtr& op) { PrintBinaryOp(op, "**"); }
void IRPythonPrinter::VisitExpr_(const MinPtr& op) { PrintFunctionBinaryOp(op, "min"); }
void IRPythonPrinter::VisitExpr_(const MaxPtr& op) { PrintFunctionBinaryOp(op, "max"); }
void IRPythonPrinter::VisitExpr_(const EqPtr& op) { PrintBinaryOp(op, "=="); }
void IRPythonPrinter::VisitExpr_(const NePtr& op) { PrintBinaryOp(op, "!="); }
void IRPythonPrinter::VisitExpr_(const LtPtr& op) { PrintBinaryOp(op, "<"); }
void IRPythonPrinter::VisitExpr_(const LePtr& op) { PrintBinaryOp(op, "<="); }
void IRPythonPrinter::VisitExpr_(const GtPtr& op) { PrintBinaryOp(op, ">"); }
void IRPythonPrinter::VisitExpr_(const GePtr& op) { PrintBinaryOp(op, ">="); }
void IRPythonPrinter::VisitExpr_(const AndPtr& op) { PrintBinaryOp(op, "and"); }
void IRPythonPrinter::VisitExpr_(const OrPtr& op) { PrintBinaryOp(op, "or"); }
void IRPythonPrinter::VisitExpr_(const XorPtr& op) { PrintBinaryOp(op, "xor"); }
void IRPythonPrinter::VisitExpr_(const BitAndPtr& op) { PrintBinaryOp(op, "&"); }
void IRPythonPrinter::VisitExpr_(const BitOrPtr& op) { PrintBinaryOp(op, "|"); }
void IRPythonPrinter::VisitExpr_(const BitXorPtr& op) { PrintBinaryOp(op, "^"); }
void IRPythonPrinter::VisitExpr_(const BitShiftLeftPtr& op) { PrintBinaryOp(op, "<<"); }
void IRPythonPrinter::VisitExpr_(const BitShiftRightPtr& op) { PrintBinaryOp(op, ">>"); }
void IRPythonPrinter::VisitExpr_(const NegPtr& op) { PrintPrefixUnaryOp(op, "-", Precedence::kUnary); }
/** \brief Print an absolute-value expression as "abs(operand)". */
void IRPythonPrinter::VisitExpr_(const AbsPtr& op)
{
    const auto& operand = op->operand_;

    stream_ << "abs(";
    VisitExpr(operand);
    stream_ << ")";
}
void IRPythonPrinter::VisitExpr_(const CastPtr& op)
{
auto scalar_type = As<ScalarType>(op->GetType());
INTERNAL_CHECK(scalar_type) << "Cast has non-scalar type";
stream_ << prefix_ << ".cast(";
VisitExpr(op->operand_);
stream_ << ", " << prefix_ << "." << BlockDslDTypeToString(scalar_type->dtype_) << ")";
}
/** \brief Print logical negation as "not <operand>". */
void IRPythonPrinter::VisitExpr_(const NotPtr& op)
{
    PrintPrefixUnaryOp(op, "not ", Precedence::kNot);
}

/** \brief Print bitwise negation as "~<operand>". */
void IRPythonPrinter::VisitExpr_(const BitNotPtr& op)
{
    PrintPrefixUnaryOp(op, "~", Precedence::kUnary);
}
/** \brief Print an assignment as "var: <type> = value". */
void IRPythonPrinter::VisitStmt_(const AssignStmtPtr& op)
{
    // Target with its type annotation.
    VisitExpr(op->var_);
    stream_ << ": ";
    stream_ << Print(op->var_->GetType());

    // Assigned value.
    stream_ << " = ";
    VisitExpr(op->value_);
}
void IRPythonPrinter::VisitStmt_(const IfStmtPtr& op)
{
stream_ << "if ";
VisitExpr(op->condition_);
stream_ << ":\n";
IncreaseIndent();
VisitStmtBody(op->thenBody_, op->returnVars_);
DecreaseIndent();
if (op->elseBody_.has_value()) {
stream_ << "\n" << GetIndent() << "else:\n";
IncreaseIndent();
VisitStmtBody(*op->elseBody_, op->returnVars_);
DecreaseIndent();
}
}
/** \brief Print a yield statement as "<prefix>.yield_(values...)". */
void IRPythonPrinter::VisitStmt_(const YieldStmtPtr& op)
{
    stream_ << prefix_ << ".yield_(";
    bool need_separator = false;
    for (const auto& yielded : op->value_) {
        if (need_separator) {
            stream_ << ", ";
        }
        VisitExpr(yielded);
        need_separator = true;
    }
    stream_ << ")";
}
/** \brief Print a return statement via PrintReturnStmtValues(). */
void IRPythonPrinter::VisitStmt_(const ReturnStmtPtr& op)
{
    auto& self = *this;
    PrintReturnStmtValues(self, stream_, op->value_);
}
void IRPythonPrinter::VisitStmt_(const ForStmtPtr& op)
{
PrintPythonRangeHeader(stream_, prefix_, op, [this](const ExprPtr& expr) { VisitExpr(expr); });
IncreaseIndent();
VisitStmtBody(op->body_, op->returnVars_);
DecreaseIndent();
}
void IRPythonPrinter::VisitStmt_(const WhileStmtPtr& op)
{
if (op->iterArgs_.empty()) {
stream_ << "while ";
VisitExpr(op->condition_);
stream_ << ":\n";
IncreaseIndent();
VisitStmtBody(op->body_, op->returnVars_);
DecreaseIndent();
} else {
PrintPythonWhileHeader(stream_, prefix_, op, [this](const ExprPtr& expr) { VisitExpr(expr); });
IncreaseIndent();
stream_ << GetIndent() << prefix_ << ".cond(";
VisitExpr(op->condition_);
stream_ << ")\n";
VisitStmtBody(op->body_, op->returnVars_);
DecreaseIndent();
}
}
void IRPythonPrinter::VisitStmt_(const SectionStmtPtr& op)
{
static const std::unordered_map<SectionKind, std::string> section_kind_to_dsl = {
{SectionKind::Vector, "section_vector"},
{SectionKind::Cube, "section_cube"},
};
auto it = section_kind_to_dsl.find(op->sectionKind_);
INTERNAL_CHECK(it != section_kind_to_dsl.end())
<< "Internal error: Unknown SectionKind in python_printer: " << SectionKindToString(op->sectionKind_);
stream_ << "with " << prefix_ << "." << it->second << "():\n";
IncreaseIndent();
PrintStmtBlock(op->body_);
DecreaseIndent();
}
/** \brief Print a statement sequence, one statement per line, with no trailing newline. */
void IRPythonPrinter::VisitStmt_(const SeqStmtsPtr& op)
{
    bool need_newline = false;
    for (const auto& stmt : op->stmts_) {
        // Newlines go between statements only.
        if (need_newline) {
            stream_ << "\n";
        }
        PrintStmtBlock(stmt);
        need_newline = true;
    }
}
/**
 * \brief Print a statement at the current indentation.
 *
 * A SeqStmts is flattened: each child is printed recursively, separated by
 * newlines. Any other statement is preceded by the current indent.
 */
void IRPythonPrinter::PrintStmtBlock(const StmtPtr& stmt)
{
    auto seq = As<SeqStmts>(stmt);
    if (!seq) {
        stream_ << GetIndent();
        VisitStmt(stmt);
        return;
    }

    const size_t count = seq->stmts_.size();
    for (size_t pos = 0; pos < count; ++pos) {
        PrintStmtBlock(seq->stmts_[pos]);
        if (pos + 1 < count) {
            stream_ << "\n";
        }
    }
}
/** \brief Print an expression statement: just the wrapped expression. */
void IRPythonPrinter::VisitStmt_(const EvalStmtPtr& op)
{
    const auto& wrapped = op->expr_;
    VisitExpr(wrapped);
}
/** \brief Print a break statement. */
void IRPythonPrinter::VisitStmt_(const BreakStmtPtr& /*op*/)
{
    stream_ << "break";
}

/** \brief Print a continue statement. */
void IRPythonPrinter::VisitStmt_(const ContinueStmtPtr& /*op*/)
{
    stream_ << "continue";
}

/** \brief Overload for a plain StmtPtr: prints the statement's type name. */
void IRPythonPrinter::VisitStmt_(const StmtPtr& op)
{
    stream_ << op->TypeName();
}
void IRPythonPrinter::PrintYieldAssignmentVars(const std::vector<VarPtr>& return_vars)
{
if (return_vars.size() == 1) {
stream_ << return_vars[0]->name_ << ": " << Print(return_vars[0]->GetType());
} else {
for (size_t i = 0; i < return_vars.size(); ++i) {
if (i > 0)
stream_ << ", ";
stream_ << return_vars[i]->name_;
}
}
}
/**
 * \brief Print a yield statement as an indented Python "return".
 *
 * With no values the output is a bare "return"; otherwise the values follow,
 * comma-separated.
 */
void IRPythonPrinter::PrintReturnFromYield(const YieldStmtPtr& yield_stmt)
{
    stream_ << GetIndent() << "return";

    const auto& values = yield_stmt->value_;
    if (values.empty()) {
        return;
    }

    stream_ << " ";
    bool need_separator = false;
    for (const auto& value : values) {
        if (need_separator) {
            stream_ << ", ";
        }
        VisitExpr(value);
        need_separator = true;
    }
}
/**
 * \brief Print "<return vars> = <prefix>.yield_(values...)" at the current indentation.
 *
 * \param yield_stmt Yield whose values form the call arguments
 * \param return_vars Variables that receive the yielded values
 */
void IRPythonPrinter::PrintYieldAssignment(const YieldStmtPtr& yield_stmt, const std::vector<VarPtr>& return_vars)
{
    // Left-hand side.
    stream_ << GetIndent();
    PrintYieldAssignmentVars(return_vars);

    // Right-hand side.
    stream_ << " = " << prefix_ << ".yield_(";
    bool need_separator = false;
    for (const auto& value : yield_stmt->value_) {
        if (need_separator) {
            stream_ << ", ";
        }
        VisitExpr(value);
        need_separator = true;
    }
    stream_ << ")";
}
void IRPythonPrinter::VisitSeqStmtBody(const SeqStmtsPtr& seq_stmts, const std::vector<VarPtr>& return_vars)
{
for (size_t i = 0; i < seq_stmts->stmts_.size(); ++i) {
auto stmt = seq_stmts->stmts_[i];
bool is_last = (i == seq_stmts->stmts_.size() - 1);
auto inner_yield = As<YieldStmt>(stmt);
if (inner_yield && is_last && !inner_yield->value_.empty() && !return_vars.empty()) {
PrintYieldAssignment(inner_yield, return_vars);
} else if (inner_yield) {
stream_ << GetIndent();
VisitStmt(stmt);
} else {
PrintStmtBlock(stmt);
}
if (i < seq_stmts->stmts_.size() - 1) {
stream_ << "\n";
}
}
}
/**
 * \brief Print the body of a control-flow statement.
 *
 * A body that is a single yield is printed as an assignment to return_vars
 * when both the yield values and return_vars are non-empty, otherwise as a
 * plain yield. A SeqStmts body is handled by VisitSeqStmtBody(); anything
 * else is printed with PrintStmtBlock().
 */
void IRPythonPrinter::VisitStmtBody(const StmtPtr& body, const std::vector<VarPtr>& return_vars)
{
    if (auto yield_stmt = As<YieldStmt>(body)) {
        const bool assigns_results = !yield_stmt->value_.empty() && !return_vars.empty();
        if (assigns_results) {
            PrintYieldAssignment(yield_stmt, return_vars);
            return;
        }
        stream_ << GetIndent();
        VisitStmt(yield_stmt);
        return;
    }

    if (auto seq_stmts = AsMut<SeqStmts>(body)) {
        VisitSeqStmtBody(seq_stmts, return_vars);
        return;
    }

    PrintStmtBlock(body);
}
/**
 * \brief Print the "@<prefix>.function" decorator line.
 *
 * Functions whose type is not OPAQUE also get a
 * "(type=<prefix>.FunctionType.<name>)" argument.
 */
void IRPythonPrinter::PrintFunctionDecorator(const FunctionPtr& func)
{
    const bool is_opaque = (func->funcType_ == FunctionType::OPAQUE);

    stream_ << GetIndent() << "@" << prefix_ << ".function";
    if (!is_opaque) {
        stream_ << "(type=" << prefix_ << ".FunctionType.";
        stream_ << FunctionTypeToString(func->funcType_) << ")";
    }
    stream_ << "\n";
}
/** \brief Print one function parameter as "name: <type>". */
void IRPythonPrinter::PrintFunctionParam(const VarPtr& var)
{
    stream_ << var->name_ << ": " << Print(var->GetType());
}
/**
 * \brief Print "def <name>(<params>)" at the current indentation.
 *
 * While a program is being printed, "self" is written as the first parameter.
 */
void IRPythonPrinter::PrintFunctionSignature(const FunctionPtr& func)
{
    stream_ << GetIndent() << "def " << func->name_ << "(";

    bool need_separator = false;
    if (current_program_) {
        stream_ << "self";
        need_separator = true;
    }
    for (const auto& param : func->params_) {
        if (need_separator) {
            stream_ << ", ";
        }
        PrintFunctionParam(param);
        need_separator = true;
    }

    stream_ << ")";
}
void IRPythonPrinter::PrintFunctionReturnType(const FunctionPtr& func)
{
PrintFunctionReturnAnnotation(stream_, func->returnTypes_, [this](const TypePtr& type) { return Print(type); });
}
/**
 * \brief Print the statements of a function body.
 *
 * Top-level yield statements are printed as Python "return"; every other
 * statement goes through PrintStmtBlock(). Nothing is printed when the
 * function has no body.
 */
void IRPythonPrinter::PrintFunctionBody(const FunctionPtr& func)
{
    if (!func->body_) {
        return;
    }

    const auto& stmts = func->body_->stmts_;
    const size_t count = stmts.size();
    for (size_t pos = 0; pos < count; ++pos) {
        // Newlines go between statements only.
        if (pos != 0) {
            stream_ << "\n";
        }
        if (auto yield_stmt = As<YieldStmt>(stmts[pos])) {
            PrintReturnFromYield(yield_stmt);
        } else {
            PrintStmtBlock(stmts[pos]);
        }
    }
}
/** \brief Print a whole function: decorator, signature, return annotation, then the indented body. */
void IRPythonPrinter::VisitFunction(const FunctionPtr& func)
{
    // Header: "@decorator" line followed by "def name(params) <return annotation>:".
    PrintFunctionDecorator(func);
    PrintFunctionSignature(func);
    PrintFunctionReturnType(func);
    stream_ << ":\n";

    // Body, one level deeper.
    IncreaseIndent();
    PrintFunctionBody(func);
    DecreaseIndent();
}
/**
 * \brief Visitor that records the name of every Call expression it visits.
 *
 * After a traversal, collected_func_names holds the distinct callee names.
 */
class FuncCallCollector : public IRVisitor {
    using IRVisitor::VisitExpr_;
    using IRVisitor::VisitStmt_;

public:
    std::set<std::string> collected_func_names;

    void VisitExpr_(const CallPtr& op) override
    {
        INTERNAL_CHECK(!op->name_.empty()) << "Call has empty name";
        const std::string& callee = op->name_;
        collected_func_names.insert(callee);
        // Hand over to the base visitor so the traversal continues.
        IRVisitor::VisitExpr_(op);
    }
};
static std::vector<std::pair<std::string, FunctionPtr>> TopologicalSortFunctions(
const std::map<std::string, FunctionPtr>& functions)
{
std::map<std::string, std::set<std::string>> dependencies;
for (const auto& [name, func] : functions) {
FuncCallCollector collector;
if (func->body_) {
collector.VisitStmt(func->body_);
}
for (const auto& called_name : collector.collected_func_names) {
if (functions.count(called_name) > 0) {
dependencies[name].insert(called_name);
}
}
}
std::vector<std::pair<std::string, FunctionPtr>> sorted;
std::set<std::string> visited;
std::set<std::string> in_progress;
std::function<bool(const std::string&)> dfs = [&](const std::string& name) -> bool {
if (visited.count(name))
return true;
if (in_progress.count(name))
return false;
in_progress.insert(name);
if (dependencies.count(name)) {
for (const auto& dep : dependencies[name]) {
if (!dfs(dep))
return false;
}
}
in_progress.erase(name);
visited.insert(name);
sorted.emplace_back(name, functions.at(name));
return true;
};
for (const auto& function_entry : functions) {
const auto& name = function_entry.first;
if (!dfs(name)) {
sorted.clear();
for (const auto& fallback_entry : functions) {
sorted.emplace_back(fallback_entry);
}
return sorted;
}
}
return sorted;
}
void IRPythonPrinter::VisitProgram(const ProgramPtr& program)
{
stream_ << "# pypto_block.program: " << (program->name_.empty() ? "Program" : program->name_) << "\n";
if (prefix_ == "pl") {
stream_ << "import pypto_block.language as pl\n\n";
} else {
stream_ << "import pypto_block.language as " << prefix_ << "\n\n";
}
stream_ << "@" << prefix_ << ".program\n";
stream_ << "class " << (program->name_.empty() ? "Program" : program->name_) << ":\n";
IncreaseIndent();
auto sorted_functions = TopologicalSortFunctions(program->functions_);
auto prev_program = current_program_;
current_program_ = program;
bool first = true;
for (const auto& entry : sorted_functions) {
const auto& func = entry.second;
if (!first) {
stream_ << "\n";
}
first = false;
VisitFunction(func);
}
current_program_ = prev_program;
DecreaseIndent();
}
void IRPythonPrinter::PrintShapeDims(std::ostringstream& oss, const std::vector<ExprPtr>& shape)
{
for (size_t i = 0; i < shape.size(); ++i) {
if (i > 0)
oss << ", ";
if (auto const_int = As<ConstInt>(shape[i])) {
oss << const_int->value_;
} else {
IRPythonPrinter temp_printer(prefix_);
oss << temp_printer.Print(shape[i]);
}
}
}
void IRPythonPrinter::PrintExprList(std::ostringstream& oss, const std::vector<ExprPtr>& exprs)
{
for (size_t i = 0; i < exprs.size(); ++i) {
if (i > 0)
oss << ", ";
IRPythonPrinter temp_printer(prefix_);
oss << temp_printer.Print(exprs[i]);
}
}
/** \brief Return "<prefix>.MemRef(<prefix>.MemorySpace.<space>, <addr>, <size>)". */
std::string IRPythonPrinter::PrintMemRef(const MemRef& memref)
{
    // The address expression is rendered by a separate printer with the same prefix.
    const std::string addr_text = IRPythonPrinter(prefix_).Print(memref.addr_);

    std::ostringstream text;
    text << prefix_ << ".MemRef(";
    text << prefix_ << ".MemorySpace." << MemorySpaceToString(memref.memorySpace_);
    text << ", " << addr_text;
    text << ", " << memref.size_ << ")";
    return text.str();
}
/** \brief Return "<prefix>.TileView(valid_shape=[...], stride=[...], start_offset=<expr>)". */
std::string IRPythonPrinter::PrintTileView(const TileView& tile_view)
{
    std::ostringstream text;
    text << prefix_ << ".TileView(";

    text << "valid_shape=[";
    PrintExprList(text, tile_view.validShape);
    text << "]";

    text << ", stride=[";
    PrintExprList(text, tile_view.stride);
    text << "]";

    // The offset expression is rendered by a separate printer with the same prefix.
    text << ", start_offset=" << IRPythonPrinter(prefix_).Print(tile_view.startOffset);

    text << ")";
    return text.str();
}
/**
 * \brief Return "<prefix>.HardwareInfo(blayout=..., slayout=..., fractal=..., pad=..., compact=...)".
 *
 * Layout, pad and compact values are printed as members of the
 * <prefix>.TileLayout, <prefix>.TilePad and <prefix>.CompactMode enums.
 */
std::string IRPythonPrinter::PrintHardwareInfo(const HardwareInfo& hw_info)
{
    const std::string layout_enum = prefix_ + ".TileLayout.";

    std::ostringstream text;
    text << prefix_ << ".HardwareInfo("
         << "blayout=" << layout_enum << TileLayoutToPythonName(hw_info.blayout)
         << ", slayout=" << layout_enum << TileLayoutToPythonName(hw_info.slayout)
         << ", fractal=" << hw_info.fractal
         << ", pad=" << prefix_ << ".TilePad." << TilePadToPythonName(hw_info.pad)
         << ", compact=" << prefix_ << ".CompactMode." << CompactModeToPythonName(hw_info.compact)
         << ")";
    return text.str();
}
std::string IRPythonPrinter::PrintTensorView(const TensorView& tensor_view)
{
std::ostringstream oss;
oss << prefix_ << ".TensorView(";
if (!tensor_view.validShape.empty()) {
oss << "valid_shape=[";
for (size_t i = 0; i < tensor_view.validShape.size(); ++i) {
if (i > 0)
oss << ", ";
IRPythonPrinter temp_printer(prefix_);
oss << temp_printer.Print(tensor_view.validShape[i]);
}
oss << "], ";
}
oss << "stride=[";
for (size_t i = 0; i < tensor_view.stride.size(); ++i) {
if (i > 0)
oss << ", ";
IRPythonPrinter temp_printer(prefix_);
oss << temp_printer.Print(tensor_view.stride[i]);
}
oss << "], layout=" << prefix_ << ".TensorLayout.";
switch (tensor_view.layout) {
case TensorLayout::ND:
oss << "ND";
break;
case TensorLayout::DN:
oss << "DN";
break;
case TensorLayout::NZ:
oss << "NZ";
break;
default:
INTERNAL_CHECK(false) << "Unknown TensorLayout in PrintTensorView";
break;
}
oss << ")";
return oss.str();
}
/**
 * \brief Print an IR node in Python DSL syntax.
 *
 * \param node IR node to print
 * \param prefix Module alias written before DSL names
 * \return Generated Python text
 */
std::string PythonDslPrint(const IRNodePtr& node, const std::string& prefix)
{
    return IRPythonPrinter(prefix).Print(node);
}

/**
 * \brief Print a type as a Python DSL type annotation.
 *
 * \param type Type to print
 * \param prefix Module alias written before DSL names
 * \return Generated annotation text
 */
std::string PythonDslPrint(const TypePtr& type, const std::string& prefix)
{
    return IRPythonPrinter(prefix).Print(type);
}
}
}