/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <cmath>
#include <cstddef>
#include <iomanip>
#include <ios>
#include <optional>
#include <sstream>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/dtype.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/type.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/memref.h"
#include "ir/stmt.h"
#include "ir/kind_traits.h"
#include "ir/scalar_expr.h"
#include "ir/transforms/base/visitor.h"
#include "ir/transforms/printer.h"

#include "tilefwk/symbolic_scalar.h"
#include "interface/tensor/ir.h"

using npu::tile_fwk::SymbolicScalar;

namespace pypto {
namespace ir {

/**
 * \brief Look up the printing precedence of an expression.
 *
 * The lookup is keyed on the dynamic type of the node. Expression kinds that are not
 * registered in the table are treated as atoms.
 *
 * \param expr Expression to classify; must not be null
 * \return Precedence registered for the node's dynamic type, or Precedence::kAtom
 */
Precedence GetPrecedence(const ExprPtr& expr)
{
    INTERNAL_CHECK(expr) << "Expression is null";

    // Built once on first use; maps the dynamic node type to its precedence level.
    static const auto kPrecedenceByType = [] {
        std::unordered_map<std::type_index, Precedence> table;
        const auto add = [&table](const std::type_info& info, Precedence prec) {
            table.emplace(std::type_index(info), prec);
        };

        // Logical operators
        add(typeid(Or), Precedence::kOr);
        add(typeid(Xor), Precedence::kXor);
        add(typeid(And), Precedence::kAnd);
        add(typeid(Not), Precedence::kNot);

        // Comparison operators
        add(typeid(Eq), Precedence::kComparison);
        add(typeid(Ne), Precedence::kComparison);
        add(typeid(Lt), Precedence::kComparison);
        add(typeid(Le), Precedence::kComparison);
        add(typeid(Gt), Precedence::kComparison);
        add(typeid(Ge), Precedence::kComparison);

        // Bitwise operators
        add(typeid(BitOr), Precedence::kBitOr);
        add(typeid(BitXor), Precedence::kBitXor);
        add(typeid(BitAnd), Precedence::kBitAnd);
        add(typeid(BitShiftLeft), Precedence::kBitShift);
        add(typeid(BitShiftRight), Precedence::kBitShift);

        // Arithmetic operators
        add(typeid(Add), Precedence::kAddSub);
        add(typeid(Sub), Precedence::kAddSub);
        add(typeid(Mul), Precedence::kMulDivMod);
        add(typeid(FloorDiv), Precedence::kMulDivMod);
        add(typeid(FloatDiv), Precedence::kMulDivMod);
        add(typeid(FloorMod), Precedence::kMulDivMod);
        add(typeid(Pow), Precedence::kPow);

        // Unary operators
        add(typeid(Neg), Precedence::kUnary);
        add(typeid(BitNot), Precedence::kUnary);

        // Function-like operators and atoms
        add(typeid(Abs), Precedence::kCall);
        add(typeid(Cast), Precedence::kCall);
        add(typeid(Min), Precedence::kCall);
        add(typeid(Max), Precedence::kCall);
        add(typeid(Call), Precedence::kCall);
        add(typeid(Var), Precedence::kAtom);
        add(typeid(ConstInt), Precedence::kAtom);
        add(typeid(ConstFloat), Precedence::kAtom);
        add(typeid(ConstBool), Precedence::kAtom);
        add(typeid(ScalarExpr), Precedence::kAtom);
        add(typeid(GetItemExpr), Precedence::kAtom);
        return table;
    }();

    // Bind to a reference first so typeid inspects the dynamic type of the pointee.
    const Expr& node = *expr;
    const auto found = kPrecedenceByType.find(std::type_index(typeid(node)));

    // Anything not listed above is printed as an atom.
    return found == kPrecedenceByType.end() ? Precedence::kAtom : found->second;
}

/**
 * \brief Tell whether an operator expression groups right-to-left.
 *
 * \param expr Expression to test
 * \return true only when the node is a Pow (printed as "**"), false for everything else
 */
bool IsRightAssociative(const ExprPtr& expr)
{
    // Power is the single right-associative operator this printer knows about.
    const bool is_power = IsA<Pow>(expr);
    return is_power;
}

bool NeedsParensForPrint(const ExprPtr& parent, const ExprPtr& child, bool is_left)
{
    Precedence parent_prec = GetPrecedence(parent);
    Precedence child_prec = GetPrecedence(child);
    if (child_prec < parent_prec) {
        return true;
    }

    if (child_prec == parent_prec) {
        if (IsRightAssociative(parent)) {
            return is_left;
        } else {
            return !is_left;
        }
    }
    return false;
}

/**
 * \brief Dispatch an arbitrary IR node to the matching top-level visitor entry point.
 *
 * The node is probed as Program, Function, Stmt and Expr, in that order. When none of
 * them matches, a placeholder text is written to \p stream instead.
 *
 * \param visitor Visitor that performs the actual printing
 * \param stream Stream that receives the placeholder for unsupported nodes
 * \param node Node to print
 */
void PrintIRNodeWithVisitor(IRVisitor& visitor, std::ostream& stream, const IRNodePtr& node)
{
    if (auto as_program = As<Program>(node)) {
        visitor.VisitProgram(as_program);
        return;
    }
    if (auto as_function = As<Function>(node)) {
        visitor.VisitFunction(as_function);
        return;
    }
    if (auto as_stmt = As<Stmt>(node)) {
        visitor.VisitStmt(as_stmt);
        return;
    }
    if (auto as_expr = As<Expr>(node)) {
        visitor.VisitExpr(as_expr);
        return;
    }
    stream << "<unsupported IRNode type>";
}

/**
 * \brief Print an operand, wrapping it in parentheses when precedence requires it.
 *
 * \param visitor Visitor used to print the operand
 * \param stream Stream that receives the parentheses
 * \param parent Enclosing operator expression
 * \param child Operand to print
 * \param is_left true when the child is the left operand of the parent
 */
void PrintChildExprWithParens(
    IRVisitor& visitor, std::ostream& stream, const ExprPtr& parent, const ExprPtr& child, bool is_left)
{
    if (!NeedsParensForPrint(parent, child, is_left)) {
        visitor.VisitExpr(child);
        return;
    }

    stream << "(";
    visitor.VisitExpr(child);
    stream << ")";
}

/**
 * \brief Print a "return" statement followed by its comma-separated values.
 *
 * \param visitor Visitor used to print each value
 * \param stream Stream that receives the keyword and separators
 * \param values Returned expressions; may be empty, in which case only "return" is written
 */
void PrintReturnStmtValues(IRVisitor& visitor, std::ostream& stream, const std::vector<ExprPtr>& values)
{
    stream << "return";

    // A space precedes the first value, a comma every later one.
    const char* separator = " ";
    for (const ExprPtr& value : values) {
        stream << separator;
        visitor.VisitExpr(value);
        separator = ", ";
    }
}

void PrintFunctionReturnAnnotation(
    std::ostream& stream, const std::vector<TypePtr>& return_types,
    const std::function<std::string(const TypePtr&)>& print_type)
{
    if (!return_types.empty()) {
        stream << " -> ";
        if (return_types.size() == 1) {
            stream << print_type(return_types[0]);
        } else {
            stream << "tuple[";
            for (size_t i = 0; i < return_types.size(); ++i) {
                if (i > 0)
                    stream << ", ";
                stream << print_type(return_types[i]);
            }
            stream << "]";
        }
    }
}

namespace {

/**
 * \brief Render a floating-point constant as literal text.
 *
 * Whole numbers (within 1e-10 of an integer, either sign) are printed with exactly one
 * decimal place, e.g. "2.0" or "-2.0", so the text still reads as a float. All other
 * values, including NaN and infinities, use the default stream formatting, which keeps
 * six significant digits.
 *
 * \param value Value to format
 * \return Textual form of the value
 */
std::string FormatFloatLiteral(double value)
{
    // Compare against the nearest integer so the test is symmetric around zero; the
    // previous fabs(value) - floor(value) form never matched negative whole numbers.
    const bool is_whole_number = std::isfinite(value) && std::fabs(value - std::round(value)) < 1e-10;

    std::ostringstream oss;
    if (is_whole_number) {
        // For integer values, format as X.0
        oss << std::fixed << std::setprecision(1) << value;
    } else {
        oss << value;
    }
    return oss.str();
}

void PrintIterArgNames(std::ostringstream& stream, const std::vector<IterArgPtr>& iter_args)
{
    stream << "(";
    for (size_t i = 0; i < iter_args.size(); ++i) {
        if (i > 0)
            stream << ", ";
        stream << iter_args[i]->iterVar_->name_;
    }
    if (iter_args.size() == 1) {
        stream << ",";
    }
    stream << ")";
}

template <typename VisitExprFn>
void PrintIterArgInitValues(
    std::ostringstream& stream, const std::vector<IterArgPtr>& iter_args, const VisitExprFn& visit_expr)
{
    stream << "init_values=(";
    for (size_t i = 0; i < iter_args.size(); ++i) {
        if (i > 0)
            stream << ", ";
        visit_expr(iter_args[i]->initValue_);
    }
    if (iter_args.size() == 1) {
        stream << ",";
    }
    stream << ")";
}

/**
 * \brief Print the header line of a for loop: "for i[, (args)] in <prefix>.range(...):".
 *
 * \param stream Destination stream
 * \param prefix Module prefix placed before ".range"
 * \param op Loop whose variable, bounds and iteration arguments are printed
 * \param visit_expr Callback that prints a single expression
 */
template <typename VisitExprFn>
void PrintForRangeHeader(
    std::ostringstream& stream, const std::string& prefix, const ForStmtPtr& op, const VisitExprFn& visit_expr)
{
    const bool has_iter_args = !op->iterArgs_.empty();

    // Loop targets: the induction variable, then the tuple of carried variables.
    stream << "for " << op->loopVar_->name_;
    if (has_iter_args) {
        stream << ", ";
        PrintIterArgNames(stream, op->iterArgs_);
    }

    // Range bounds: start, stop, step.
    stream << " in " << prefix << ".range(";
    visit_expr(op->start_);
    stream << ", ";
    visit_expr(op->stop_);
    stream << ", ";
    visit_expr(op->step_);

    // Initial values of the carried variables come last, as a keyword argument.
    if (has_iter_args) {
        stream << ", ";
        PrintIterArgInitValues(stream, op->iterArgs_, visit_expr);
    }
    stream << "):\n";
}

/**
 * \brief Print the header line of a while loop that carries iteration arguments:
 *        "for (args) in <prefix>.while_(init_values=(...)):".
 *
 * \param stream Destination stream
 * \param prefix Module prefix placed before ".while_"
 * \param op Loop whose iteration arguments are printed
 * \param visit_expr Callback that prints a single expression
 */
template <typename VisitExprFn>
void PrintWhileIterArgsHeader(
    std::ostringstream& stream, const std::string& prefix, const WhileStmtPtr& op, const VisitExprFn& visit_expr)
{
    const auto& carried = op->iterArgs_;

    stream << "for ";
    PrintIterArgNames(stream, carried);

    stream << " in " << prefix << ".while_(";
    PrintIterArgInitValues(stream, carried, visit_expr);
    stream << "):\n";
}
} // namespace

/**
 * \brief Python-style IR printer
 *
 * Prints IR nodes in Python syntax with type annotations and SSA-style control flow.
 * This is the recommended printer for new code that outputs valid Python syntax.
 *
 * Key features:
 * - Type annotations (e.g., x: pl.INT64, a: pl.Tensor[[4, 8], pl.FP32])
 * - SSA-style if/for with pl.yield_() and pl.range()
 * - Op attributes as keyword arguments
 * - Program headers with # ir.program: name
 *
 * A printer instance accumulates text in an internal stream; Print(node) resets that
 * stream and the indentation level before every run.
 */
class IRPrinter : public IRVisitor {
    // Keep the base-class overloads visible next to the ones declared below.
    using IRVisitor::VisitExpr_;
    using IRVisitor::VisitStmt_;

public:
    /**
     * \brief Create a printer
     *
     * \param prefix Module prefix written before type and helper names (e.g. "pl" or "ir")
     * \param concise When true, type annotations on assignments and yield targets are omitted
     */
    explicit IRPrinter(std::string prefix = "ir", bool concise = false) : prefix_(std::move(prefix)), concise_(concise)
    {}
    ~IRPrinter() override = default;

    /**
     * \brief Print an IR node to a string in Python IR syntax
     *
     * \param node IR node to print (can be Expr, Stmt, Function, or Program)
     * \return Python-style string representation
     */
    std::string Print(const IRNodePtr& node);

    /**
     * \brief Print a type as it appears in annotations
     *
     * \param type Type to print
     * \return Annotation text; types not recognized by the printer yield "<prefix>.Unknown"
     */
    std::string Print(const TypePtr& type);

protected:
    // Visitor overrides shared between printers; the macro is defined outside this file.
    PYPTO_IR_PRINTER_COMMON_VISITOR_OVERRIDES();
    void VisitExpr_(const ScalarExprPtr& op) override;
    void VisitStmt_(const TensorOpStmtPtr& op) override;
    void VisitStmt_(const ScalarOpStmtPtr& op) override;

private:
    std::ostringstream stream_; // Accumulates the printed text
    int indent_ = 0;            // Current nesting depth; one level is four spaces
    std::string prefix_; // Prefix for type names (e.g., "pl" or "ir")
    bool concise_;       // When true, omit intermediate type annotations

    // Returns the whitespace for the current nesting depth.
    std::string GetIndent() const;

    // SeqStmts is a transparent container - recursed into without extra indent.
    void PrintStmtBlock(const StmtPtr& stmt);

    // Statement body visitor with SSA-style handling
    void VisitStmtBody(const StmtPtr& body, const std::vector<VarPtr>& return_vars = {});
    // Prints a function body, turning yield statements into "return".
    void VisitFunctionBody(const StmtPtr& body);
    // Prints the left-hand side of a "vars = <prefix>.yield_(...)" assignment.
    void PrintYieldAssignmentVars(const std::vector<VarPtr>& return_vars);

    // Binary/unary operator helpers (reuse precedence logic)
    void PrintBinaryOp(const BinaryExprPtr& op, const char* op_symbol);
    void PrintFunctionBinaryOp(const BinaryExprPtr& op, const char* func_name);
    void PrintChild(const ExprPtr& parent, const ExprPtr& child, bool is_left);

    // Shape printing helper
    void PrintShapeDims(std::ostringstream& oss, const std::vector<ExprPtr>& shape);

    // Print an expression for use in type annotations (shapes, views).
    std::string PrintExprForType(const ExprPtr& expr);

    // MemRef and TileView printing helpers
    std::string PrintMemRef(const MemRef& memref);
};

// DataTypeToPythonString removed — now uses DTypeToString from dtype.h

// IRPrinter implementation
/**
 * \brief Print an IR node from scratch.
 *
 * Any text and indentation left over from a previous run is discarded first.
 *
 * \param node Node to print
 * \return Text produced for the node
 */
std::string IRPrinter::Print(const IRNodePtr& node)
{
    // Start from a clean slate: no indentation, no error flags, no buffered text.
    indent_ = 0;
    stream_.clear();
    stream_.str("");

    PrintIRNodeWithVisitor(*this, stream_, node);
    return stream_.str();
}

std::string IRPrinter::Print(const TypePtr& type)
{
    if (auto scalar_type = As<ScalarType>(type)) {
        // Print as pl.Scalar[pl.INT64] for proper round-trip support
        return prefix_ + ".Scalar[" + prefix_ + "." + DTypeToString(scalar_type->dtype_) + "]";
    }

    if (auto tensor_type = As<TensorType>(type)) {
        std::ostringstream oss;
        // Subscript-style: pl.Tensor[[shape], dtype]
        oss << prefix_ << ".Tensor[[";
        PrintShapeDims(oss, tensor_type->shape_);
        oss << "], " << prefix_ << "." << DTypeToString(tensor_type->dtype_);

        // Add optional memref as positional arg
        if (tensor_type->memref_.has_value()) {
            oss << ", " << PrintMemRef(*tensor_type->memref_.value());
        }

        oss << "]";
        return oss.str();
    }

    if (auto tile_type = As<TileType>(type)) {
        std::ostringstream oss;
        // Subscript-style: pl.Tile[[shape], dtype]
        oss << prefix_ << ".Tile[[";
        PrintShapeDims(oss, tile_type->shape_);
        oss << "], " << prefix_ << "." << DTypeToString(tile_type->dtype_);

        // Add optional memref as positional arg
        if (tile_type->memref_.has_value()) {
            oss << ", " << PrintMemRef(*tile_type->memref_.value());
        }

        oss << "]";
        return oss.str();
    }

    if (auto tuple_type = As<TupleType>(type)) {
        std::ostringstream oss;
        if (tuple_type->types_.empty()) {
            oss << prefix_ << ".Tuple[()]";
        } else {
            oss << prefix_ << ".Tuple[";
            for (size_t i = 0; i < tuple_type->types_.size(); ++i) {
                if (i > 0)
                    oss << ", ";
                oss << Print(tuple_type->types_[i]);
            }
            oss << "]";
        }
        return oss.str();
    }

    if (auto memref_type = As<MemRefType>(type)) {
        return prefix_ + ".MemRefType";
    }

    if (auto ptr_type = As<PtrType>(type)) {
        return prefix_ + ".Ptr";
    }

    if (auto token_type = As<TokenType>(type)) {
        return prefix_ + ".Token";
    }

    if (auto logical_tensor_type = As<LogicalTensorType>(type)) {
        return prefix_ + ".Tensor";
    }

    return prefix_ + ".Unknown";
}

/**
 * \brief Build the leading whitespace for the current nesting depth.
 *
 * \return A string of four spaces per indentation level
 */
std::string IRPrinter::GetIndent() const
{
    const auto width = static_cast<size_t>(indent_ * 4);
    return std::string(width, ' ');
}

// Expression visitors - reuse precedence logic from base printer
/**
 * \brief Print a variable reference.
 *
 * Variables of logical-tensor type are rendered through DumpTensorVar; every other
 * variable prints as its plain name.
 */
void IRPrinter::VisitExpr_(const VarPtr& op)
{
    if (!As<LogicalTensorType>(op->GetType())) {
        stream_ << op->name_;
        return;
    }
    stream_ << DumpTensorVar(op);
}

/** \brief Print a memory reference via PrintMemRef. */
void IRPrinter::VisitExpr_(const MemRefPtr& op)
{
    stream_ << PrintMemRef(*op);
}

/** \brief Print an integer constant as its value. */
void IRPrinter::VisitExpr_(const ConstIntPtr& op)
{
    stream_ << op->value_;
}

/** \brief Print a floating-point constant through FormatFloatLiteral. */
void IRPrinter::VisitExpr_(const ConstFloatPtr& op)
{
    stream_ << FormatFloatLiteral(op->value_);
}

/** \brief Print a boolean constant using the Python spellings True / False. */
void IRPrinter::VisitExpr_(const ConstBoolPtr& op)
{
    if (op->value_) {
        stream_ << "True";
    } else {
        stream_ << "False";
    }
}

/** \brief Print a call as "<prefix>.call @name(arg, ...)". */
void IRPrinter::VisitExpr_(const CallPtr& op)
{
    stream_ << prefix_ << ".call @" << op->name_ << "(";

    const char* separator = "";
    for (const auto& arg : op->args_) {
        stream_ << separator;
        VisitExpr(arg);
        separator = ", ";
    }

    stream_ << ")";
}

/** \brief Print a tuple construction as a bracketed, comma-separated element list. */
void IRPrinter::VisitExpr_(const MakeTuplePtr& op)
{
    stream_ << "[";

    const char* separator = "";
    for (const auto& element : op->elements_) {
        stream_ << separator;
        VisitExpr(element);
        separator = ", ";
    }

    stream_ << "]";
}

/** \brief Print a subscript expression as "value[slice]". */
void IRPrinter::VisitExpr_(const GetItemExprPtr& op)
{
    // Subscripted value first ...
    VisitExpr(op->value_);

    // ... then the slice inside square brackets.
    stream_ << "[";
    VisitExpr(op->slice_);
    stream_ << "]";
}

/**
 * \brief Print a scalar expression through DumpScalarExpr.
 *
 * The node's type is checked to be a ScalarType before anything is written.
 */
void IRPrinter::VisitExpr_(const ScalarExprPtr& op)
{
    // Only the check needs the cast result; the dump works on the node itself.
    const auto scalar_type = As<ScalarType>(op->GetType());
    INTERNAL_CHECK_SPAN(scalar_type, op->span_) << "ScalarExpr has non-scalar type";

    stream_ << DumpScalarExpr(op);
}

// Binary and unary operators - reuse from base printer logic
void IRPrinter::PrintChild(const ExprPtr& parent, const ExprPtr& child, bool is_left)
{
    PrintChildExprWithParens(*this, stream_, parent, child, is_left);
}

/**
 * \brief Print an infix binary operation as "left <symbol> right".
 *
 * \param op Binary expression to print
 * \param op_symbol Operator text placed between the operands, surrounded by spaces
 */
void IRPrinter::PrintBinaryOp(const BinaryExprPtr& op, const char* op_symbol)
{
    constexpr bool kLeftOperand = true;
    constexpr bool kRightOperand = false;

    PrintChild(op, op->left_, kLeftOperand);
    stream_ << " " << op_symbol << " ";
    PrintChild(op, op->right_, kRightOperand);
}

/**
 * \brief Print a binary operation in call form: "<prefix>.<func_name>(left, right)".
 *
 * \param op Binary expression to print
 * \param func_name Function name written after the prefix
 */
void IRPrinter::PrintFunctionBinaryOp(const BinaryExprPtr& op, const char* func_name)
{
    // Call head
    stream_ << prefix_ << "." << func_name << "(";

    // Operands never need parentheses inside an argument list.
    VisitExpr(op->left_);
    stream_ << ", ";
    VisitExpr(op->right_);

    stream_ << ")";
}

// Arithmetic binary operators: each one is printed infix with its Python symbol.
void IRPrinter::VisitExpr_(const AddPtr& op)
{
    PrintBinaryOp(op, "+");
}

void IRPrinter::VisitExpr_(const SubPtr& op)
{
    PrintBinaryOp(op, "-");
}

void IRPrinter::VisitExpr_(const MulPtr& op)
{
    PrintBinaryOp(op, "*");
}

void IRPrinter::VisitExpr_(const FloorDivPtr& op)
{
    PrintBinaryOp(op, "//");
}

void IRPrinter::VisitExpr_(const FloorModPtr& op)
{
    PrintBinaryOp(op, "%");
}

void IRPrinter::VisitExpr_(const FloatDivPtr& op)
{
    PrintBinaryOp(op, "/");
}

void IRPrinter::VisitExpr_(const PowPtr& op)
{
    PrintBinaryOp(op, "**");
}

// Function-style binary operators: printed as <prefix>.min(a, b) / <prefix>.max(a, b).
void IRPrinter::VisitExpr_(const MinPtr& op)
{
    PrintFunctionBinaryOp(op, "min");
}

void IRPrinter::VisitExpr_(const MaxPtr& op)
{
    PrintFunctionBinaryOp(op, "max");
}

// Comparison operators: printed infix with the usual symbols.
void IRPrinter::VisitExpr_(const EqPtr& op)
{
    PrintBinaryOp(op, "==");
}

void IRPrinter::VisitExpr_(const NePtr& op)
{
    PrintBinaryOp(op, "!=");
}

void IRPrinter::VisitExpr_(const LtPtr& op)
{
    PrintBinaryOp(op, "<");
}

void IRPrinter::VisitExpr_(const LePtr& op)
{
    PrintBinaryOp(op, "<=");
}

void IRPrinter::VisitExpr_(const GtPtr& op)
{
    PrintBinaryOp(op, ">");
}

void IRPrinter::VisitExpr_(const GePtr& op)
{
    PrintBinaryOp(op, ">=");
}

// Logical operators: printed infix as keywords.
void IRPrinter::VisitExpr_(const AndPtr& op)
{
    PrintBinaryOp(op, "and");
}

void IRPrinter::VisitExpr_(const OrPtr& op)
{
    PrintBinaryOp(op, "or");
}

// NOTE(review): Python itself has no "xor" keyword - confirm that whatever consumes
// this text accepts it.
void IRPrinter::VisitExpr_(const XorPtr& op)
{
    PrintBinaryOp(op, "xor");
}

// Bitwise operators: printed infix with their symbols.
void IRPrinter::VisitExpr_(const BitAndPtr& op)
{
    PrintBinaryOp(op, "&");
}

void IRPrinter::VisitExpr_(const BitOrPtr& op)
{
    PrintBinaryOp(op, "|");
}

void IRPrinter::VisitExpr_(const BitXorPtr& op)
{
    PrintBinaryOp(op, "^");
}

void IRPrinter::VisitExpr_(const BitShiftLeftPtr& op)
{
    PrintBinaryOp(op, "<<");
}

void IRPrinter::VisitExpr_(const BitShiftRightPtr& op)
{
    PrintBinaryOp(op, ">>");
}

// Unary operators
void IRPrinter::VisitExpr_(const NegPtr& op)
{
    stream_ << "-";
    Precedence operand_prec = GetPrecedence(op->operand_);
    if (operand_prec < Precedence::kUnary) {
        stream_ << "(";
        VisitExpr(op->operand_);
        stream_ << ")";
    } else {
        VisitExpr(op->operand_);
    }
}

/** \brief Print an absolute value as "<prefix>.abs(operand)". */
void IRPrinter::VisitExpr_(const AbsPtr& op)
{
    // Call head
    stream_ << prefix_ << ".abs(";

    // The single argument; no parentheses are needed inside a call.
    VisitExpr(op->operand_);

    stream_ << ")";
}

/**
 * \brief Print a cast as "<prefix>.cast(operand, <prefix>.<dtype>)".
 *
 * The target dtype is taken from the cast node's own type, which is checked to be a
 * ScalarType.
 */
void IRPrinter::VisitExpr_(const CastPtr& op)
{
    const auto scalar_type = As<ScalarType>(op->GetType());
    INTERNAL_CHECK_SPAN(scalar_type, op->span_) << "Cast has non-scalar type";

    // Value being converted
    stream_ << prefix_ << ".cast(";
    VisitExpr(op->operand_);

    // Destination dtype
    stream_ << ", " << prefix_ << "." << DTypeToString(scalar_type->dtype_) << ")";
}

void IRPrinter::VisitExpr_(const NotPtr& op)
{
    stream_ << "not ";
    Precedence operand_prec = GetPrecedence(op->operand_);
    if (operand_prec < Precedence::kNot) {
        stream_ << "(";
        VisitExpr(op->operand_);
        stream_ << ")";
    } else {
        VisitExpr(op->operand_);
    }
}

void IRPrinter::VisitExpr_(const BitNotPtr& op)
{
    stream_ << "~";
    Precedence operand_prec = GetPrecedence(op->operand_);
    if (operand_prec < Precedence::kUnary) {
        stream_ << "(";
        VisitExpr(op->operand_);
        stream_ << ")";
    } else {
        VisitExpr(op->operand_);
    }
}

// Statement visitors with proper Python syntax
/**
 * \brief Print an assignment.
 *
 * Normal mode:  var: type = value
 * Concise mode: var = value
 */
void IRPrinter::VisitStmt_(const AssignStmtPtr& op)
{
    // Target
    VisitExpr(op->var_);

    // Optional annotation, then the assignment sign
    if (concise_) {
        stream_ << " = ";
    } else {
        stream_ << ": " << Print(op->var_->GetType());
        stream_ << " = ";
    }

    // Assigned value
    VisitExpr(op->value_);
}

/**
 * \brief Print an if statement with an optional else branch.
 *
 * Both branches are printed through VisitStmtBody together with the statement's return
 * variables, so a trailing yield becomes an assignment to those variables.
 */
void IRPrinter::VisitStmt_(const IfStmtPtr& op)
{
    // Header: "if <condition>:"
    stream_ << "if ";
    VisitExpr(op->condition_);
    stream_ << ":\n";

    // Then-branch, one level deeper
    ++indent_;
    VisitStmtBody(op->thenBody_, op->returnVars_);
    --indent_;

    if (!op->elseBody_.has_value()) {
        return;
    }

    // Else-branch: keyword at the level of the "if", body one level deeper
    stream_ << "\n" << GetIndent() << "else:\n";
    ++indent_;
    VisitStmtBody(*op->elseBody_, op->returnVars_);
    --indent_;
}

/**
 * \brief Print a bare yield as "<prefix>.yield_(value, ...)".
 *
 * Yields directly inside a function body are handled by VisitFunctionBody and printed
 * as "return" there instead.
 */
void IRPrinter::VisitStmt_(const YieldStmtPtr& op)
{
    stream_ << prefix_ << ".yield_(";

    const char* separator = "";
    for (const auto& value : op->value_) {
        stream_ << separator;
        VisitExpr(value);
        separator = ", ";
    }

    stream_ << ")";
}

void IRPrinter::VisitStmt_(const ReturnStmtPtr& op) { PrintReturnStmtValues(*this, stream_, op->value_); }

void IRPrinter::VisitStmt_(const ForStmtPtr& op)
{
    PrintForRangeHeader(stream_, prefix_, op, [this](const ExprPtr& expr) { VisitExpr(expr); });

    indent_++;
    VisitStmtBody(op->body_, op->returnVars_);
    indent_--;
}

/**
 * \brief Print a while loop.
 *
 * Without iteration arguments the loop prints as a plain "while <condition>:". With
 * iteration arguments it prints in explicit DSL form: a "for (...) in <prefix>.while_(...)"
 * header whose first body line is "<prefix>.cond(<condition>)".
 */
void IRPrinter::VisitStmt_(const WhileStmtPtr& op)
{
    const bool ssa_style = !op->iterArgs_.empty();

    // Header line
    if (ssa_style) {
        PrintWhileIterArgsHeader(stream_, prefix_, op, [this](const ExprPtr& expr) { VisitExpr(expr); });
    } else {
        stream_ << "while ";
        VisitExpr(op->condition_);
        stream_ << ":\n";
    }

    ++indent_;

    // In SSA form the loop condition becomes the first statement of the body.
    if (ssa_style) {
        stream_ << GetIndent() << prefix_ << ".cond(";
        VisitExpr(op->condition_);
        stream_ << ")\n";
    }

    VisitStmtBody(op->body_, op->returnVars_);
    --indent_;
}

/** \brief Print a statement sequence, one statement per line with no trailing newline. */
void IRPrinter::VisitStmt_(const SeqStmtsPtr& op)
{
    const auto& stmts = op->stmts_;
    const size_t count = stmts.size();

    for (size_t index = 0; index < count; ++index) {
        // Newlines go between statements only.
        if (index > 0) {
            stream_ << "\n";
        }
        PrintStmtBlock(stmts[index]);
    }
}

/**
 * \brief Print a statement at the current indentation.
 *
 * A SeqStmts is a transparent container: its children are printed recursively, separated
 * by newlines, without any extra indentation. Any other statement is prefixed with the
 * current indent and dispatched to the visitor.
 */
void IRPrinter::PrintStmtBlock(const StmtPtr& stmt)
{
    auto seq = As<SeqStmts>(stmt);
    if (!seq) {
        stream_ << GetIndent();
        VisitStmt(stmt);
        return;
    }

    const auto& children = seq->stmts_;
    const size_t count = children.size();
    for (size_t index = 0; index < count; ++index) {
        if (index > 0) {
            stream_ << "\n";
        }
        PrintStmtBlock(children[index]);
    }
}

/** \brief Print an expression statement as "<prefix>.eval(expr)". */
void IRPrinter::VisitStmt_(const EvalStmtPtr& op)
{
    // Wrapper head
    stream_ << prefix_ << ".eval(";

    // The evaluated expression
    VisitExpr(op->expr_);

    stream_ << ")";
}

/** \brief Print "break", followed by its comma-separated values when it carries any. */
void IRPrinter::VisitStmt_(const BreakStmtPtr& op)
{
    stream_ << "break";

    // A space precedes the first value, a comma every later one.
    const char* separator = " ";
    for (const auto& value : op->value_) {
        stream_ << separator;
        VisitExpr(value);
        separator = ", ";
    }
}

/** \brief Print "continue", followed by its comma-separated values when it carries any. */
void IRPrinter::VisitStmt_(const ContinueStmtPtr& op)
{
    stream_ << "continue";

    // A space precedes the first value, a comma every later one.
    const char* separator = " ";
    for (const auto& value : op->value_) {
        stream_ << separator;
        VisitExpr(value);
        separator = ", ";
    }
}

/**
 * \brief Print a tensor operation statement.
 *
 * Layout: res [, res_token] = opcode(args, tokens=[...], attrs=[key=value, ...])
 *
 * A single result prints bare; any other result count prints as a bracketed list. The
 * "tokens" and "attrs" groups are emitted only when non-empty. Separators are tracked
 * across args, tokens and attrs, so the parenthesized list never starts with a stray
 * comma when the op has no positional arguments.
 *
 * Supported attribute value types: int, double, float, bool, std::string, SymbolicScalar,
 * std::vector<int> and std::vector<SymbolicScalar>; any other type fails an internal check.
 */
void IRPrinter::VisitStmt_(const TensorOpStmtPtr& op)
{
    // Left-hand side: results, then the optional result token.
    if (op->result_.size() == 1) {
        VisitExpr(op->result_[0]);
    } else {
        stream_ << "[";
        for (size_t i = 0; i < op->result_.size(); ++i) {
            if (i > 0)
                stream_ << ", ";
            VisitExpr(op->result_[i]);
        }
        stream_ << "]";
    }
    if (op->result_token_) {
        stream_ << ", ";
        VisitExpr(op->result_token_);
    }
    stream_ << " = " << op->opcode_ << "(";

    // Emits ", " before every top-level entry except the first one, regardless of
    // whether that first entry is an argument, the token list or the attribute list.
    bool need_separator = false;
    const auto begin_entry = [&]() {
        if (need_separator) {
            stream_ << ", ";
        }
        need_separator = true;
    };

    for (size_t i = 0; i < op->args_.size(); ++i) {
        begin_entry();
        VisitExpr(op->args_[i]);
    }

    if (!op->tokens_.empty()) {
        begin_entry();
        stream_ << "tokens=[";
        for (size_t i = 0; i < op->tokens_.size(); ++i) {
            if (i > 0)
                stream_ << ", ";
            VisitExpr(op->tokens_[i]);
        }
        stream_ << "]";
    }

    // Prints one attribute as key=value, choosing the format from the stored type.
    auto printValue = [&](const std::string& key, const std::any& value) {
        stream_ << key << "=";
        if (value.type() == typeid(int)) {
            stream_ << AnyCast<int>(value);
        } else if (value.type() == typeid(double)) {
            stream_ << FormatFloatLiteral(AnyCast<double>(value));
        } else if (value.type() == typeid(float)) {
            stream_ << FormatFloatLiteral(static_cast<double>(AnyCast<float>(value)));
        } else if (value.type() == typeid(bool)) {
            stream_ << (AnyCast<bool>(value) ? "True" : "False");
        } else if (value.type() == typeid(std::string)) {
            stream_ << std::quoted(AnyCast<std::string>(value));
        } else if (value.type() == typeid(SymbolicScalar)) {
            stream_ << (AnyCast<SymbolicScalar>(value).Dump());
        } else if (value.type() == typeid(std::vector<int>)) {
            // Bind by const reference so no extra copy of the vector is made here.
            const auto& values = AnyCast<std::vector<int>>(value);
            stream_ << "[";
            for (size_t i = 0; i < values.size(); ++i) {
                if (i > 0)
                    stream_ << ", ";
                stream_ << values[i];
            }
            stream_ << "]";
        } else if (value.type() == typeid(std::vector<SymbolicScalar>)) {
            const auto& values = AnyCast<std::vector<SymbolicScalar>>(value);
            stream_ << "[";
            for (size_t i = 0; i < values.size(); ++i) {
                if (i > 0)
                    stream_ << ", ";
                stream_ << values[i].Dump();
            }
            stream_ << "]";
        } else {
            INTERNAL_CHECK(false) << "Unsupported function attrs value type: " << DemangleTypeName(value.type().name());
        }
    };

    if (!op->attrs_.empty()) {
        begin_entry();
        stream_ << "attrs=[";
        for (size_t i = 0; i < op->attrs_.size(); ++i) {
            if (i > 0)
                stream_ << ", ";
            printValue(op->attrs_[i].first, op->attrs_[i].second);
        }
        stream_ << "]";
    }
    stream_ << ")";
}

/**
 * \brief Print a scalar operation statement.
 *
 * Layout: res [, res_token] = opcode(args)
 */
void IRPrinter::VisitStmt_(const ScalarOpStmtPtr& op)
{
    // Left-hand side: the result, then the optional result token.
    VisitExpr(op->result_);
    if (op->result_token_) {
        stream_ << ", ";
        VisitExpr(op->result_token_);
    }

    // Right-hand side: opcode applied to the comma-separated arguments.
    stream_ << " = " << op->opcode_ << "(";
    const char* separator = "";
    for (const auto& arg : op->args_) {
        stream_ << separator;
        VisitExpr(arg);
        separator = ", ";
    }
    stream_ << ")";
}

/** \brief Fallback for statements handled through the generic overload: print the node's type name. */
void IRPrinter::VisitStmt_(const StmtPtr& op)
{
    stream_ << op->TypeName();
}

void IRPrinter::PrintYieldAssignmentVars(const std::vector<VarPtr>& return_vars)
{
    if (return_vars.size() == 1) {
        stream_ << return_vars[0]->name_;
        if (!concise_) {
            stream_ << ": " << Print(return_vars[0]->GetType());
        }
    } else {
        for (size_t i = 0; i < return_vars.size(); ++i) {
            if (i > 0)
                stream_ << ", ";
            stream_ << return_vars[i]->name_;
        }
    }
}

void IRPrinter::VisitStmtBody(const StmtPtr& body, const std::vector<VarPtr>& return_vars)
{
    // Helper to visit statement body and wrap YieldStmt with assignment if needed
    YieldStmtPtr yield_stmt = As<YieldStmt>(body);
    if (yield_stmt) {
        // If parent has return_vars, wrap yield as assignment
        if (!yield_stmt->value_.empty() && !return_vars.empty()) {
            stream_ << GetIndent();
            PrintYieldAssignmentVars(return_vars);
            stream_ << " = " << prefix_ << ".yield_(";
            for (size_t i = 0; i < yield_stmt->value_.size(); ++i) {
                if (i > 0)
                    stream_ << ", ";
                VisitExpr(yield_stmt->value_[i]);
            }
            stream_ << ")";
        } else {
            stream_ << GetIndent();
            VisitStmt(yield_stmt);
        }
    } else if (auto seq_stmts = As<SeqStmts>(body)) {
        // Process each statement in sequence
        if (seq_stmts->stmts_.empty()) {
            stream_ << GetIndent() << "pass";
            return;
        }
        for (size_t i = 0; i < seq_stmts->stmts_.size(); ++i) {
            auto stmt = seq_stmts->stmts_[i];

            // Check if this is the last statement and it's a YieldStmt
            bool is_last = (i == seq_stmts->stmts_.size() - 1);
            yield_stmt = As<YieldStmt>(stmt);
            if (yield_stmt) {
                if (is_last && !yield_stmt->value_.empty() && !return_vars.empty()) {
                    // Wrap as assignment
                    stream_ << GetIndent();
                    PrintYieldAssignmentVars(return_vars);
                    stream_ << " = " << prefix_ << ".yield_(";
                    for (size_t j = 0; j < yield_stmt->value_.size(); ++j) {
                        if (j > 0)
                            stream_ << ", ";
                        VisitExpr(yield_stmt->value_[j]);
                    }
                    stream_ << ")";
                } else {
                    stream_ << GetIndent();
                    VisitStmt(stmt);
                }
            } else {
                PrintStmtBlock(stmt);
            }

            if (i < seq_stmts->stmts_.size() - 1) {
                stream_ << "\n";
            }
        }
    } else {
        PrintStmtBlock(body);
    }
}

void IRPrinter::VisitFunctionBody(const StmtPtr& body)
{
    if (auto seq_stmts = As<SeqStmts>(body)) {
        if (seq_stmts->stmts_.empty()) {
            stream_ << GetIndent() << "pass";
        } else {
            for (size_t i = 0; i < seq_stmts->stmts_.size(); ++i) {
                // Convert yield to return in function context
                if (auto yield_stmt = As<YieldStmt>(seq_stmts->stmts_[i])) {
                    stream_ << GetIndent() << "return";
                    if (!yield_stmt->value_.empty()) {
                        stream_ << " ";
                        for (size_t j = 0; j < yield_stmt->value_.size(); ++j) {
                            if (j > 0)
                                stream_ << ", ";
                            VisitExpr(yield_stmt->value_[j]);
                        }
                    }
                } else {
                    PrintStmtBlock(seq_stmts->stmts_[i]);
                }
                if (i < seq_stmts->stmts_.size() - 1) {
                    stream_ << "\n";
                }
            }
        }
    } else if (auto yield_stmt = As<YieldStmt>(body)) {
        stream_ << GetIndent() << "return";
        if (!yield_stmt->value_.empty()) {
            stream_ << " ";
            for (size_t i = 0; i < yield_stmt->value_.size(); ++i) {
                if (i > 0)
                    stream_ << ", ";
                VisitExpr(yield_stmt->value_[i]);
            }
        }
    } else {
        PrintStmtBlock(body);
    }
}

void IRPrinter::VisitFunction(const FunctionPtr& func)
{
    // Print decorator
    stream_ << GetIndent() << "@" << prefix_ << ".function";
    stream_ << "\n";

    // Print function signature
    stream_ << GetIndent() << "def " << func->name_ << "(";

    // Print parameters with type annotations and direction wrappers
    for (size_t i = 0; i < func->params_.size(); ++i) {
        if (i > 0)
            stream_ << ", ";
        const auto& var = func->params_[i];
        stream_ << var->name_ << ": ";
        stream_ << Print(var->GetType());
    }

    stream_ << ")";

    PrintFunctionReturnAnnotation(stream_, func->returnTypes_, [this](const TypePtr& type) { return Print(type); });

    stream_ << ":\n";

    // Print body - convert yield to return in function context
    indent_++;
    if (func->body_) {
        VisitFunctionBody(func->body_);
    }
    indent_--;
}

/**
 * \brief Print a whole program.
 *
 * Writes a "# ir.program: <name>" header line ("Program" when the name is empty) and
 * then every function, separated from the previous one by a newline.
 */
void IRPrinter::VisitProgram(const ProgramPtr& program)
{
    // Header comment
    stream_ << "# ir.program: ";
    if (program->name_.empty()) {
        stream_ << "Program";
    } else {
        stream_ << program->name_;
    }
    stream_ << "\n";

    // Functions; the separator is written before every function but the first.
    const char* gap = "";
    for (const auto& entry : program->functions_) {
        stream_ << gap;
        gap = "\n";
        VisitFunction(entry.second);
    }
}

std::string IRPrinter::PrintExprForType(const ExprPtr& expr)
{
    if (auto const_int = As<ConstInt>(expr)) {
        return std::to_string(const_int->value_);
    }
    if (auto var = As<Var>(expr)) {
        return var->name_;
    }
    IRPrinter temp_printer(prefix_);
    return temp_printer.Print(expr);
}

void IRPrinter::PrintShapeDims(std::ostringstream& oss, const std::vector<ExprPtr>& shape)
{
    for (size_t i = 0; i < shape.size(); ++i) {
        if (i > 0)
            oss << ", ";
        oss << PrintExprForType(shape[i]);
    }
}

/**
 * \brief Render a memory reference as
 *        "<prefix>.MemRef(<prefix>.MemorySpace.<space>, <addr>, <size>)".
 *
 * The address is printed by a separate printer instance that shares this printer's
 * prefix, so this printer's own stream is not disturbed.
 *
 * \param memref Memory reference to render
 * \return Text of the memory reference
 */
std::string IRPrinter::PrintMemRef(const MemRef& memref)
{
    const std::string addr_text = IRPrinter(prefix_).Print(memref.addr_);

    std::ostringstream text;
    text << prefix_ << ".MemRef(";
    text << prefix_ << ".MemorySpace." << MemorySpaceToString(memref.memorySpace_) << ", ";
    text << addr_text;
    text << ", " << memref.size_ << ")";
    return text.str();
}

/**
 * \brief Print an IR node in Python syntax.
 *
 * \param node Node to print
 * \param prefix Module prefix used for type and helper names
 * \param concise When true, intermediate type annotations are omitted
 * \return Printed text
 */
std::string PythonPrint(const IRNodePtr& node, const std::string& prefix, bool concise)
{
    return IRPrinter(prefix, concise).Print(node);
}

/**
 * \brief Print a type as Python annotation text.
 *
 * \param type Type to print
 * \param prefix Module prefix used for type names
 * \return Printed text
 */
std::string PythonPrint(const TypePtr& type, const std::string& prefix)
{
    return IRPrinter(prefix).Print(type);
}
} // namespace ir
} // namespace pypto