import gc
from mlir.ir import *
def run(f):
    """Run a single test function under a labelled banner.

    Prints a ``TEST: <name>`` header, calls ``f()``, forces a garbage
    collection and then asserts that ``Context._get_live_count()`` reports
    zero (presumably: no MLIR contexts leaked by the test — confirm against
    the bindings' docs).
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect first so objects kept alive only by reference cycles are freed
    # before the live count is read.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
def test_insert_at_block_end():
    """Insert through an insertion point built directly from a block.

    ``InsertionPoint(block)`` must report that block and no reference
    operation; ``custom.op2`` is then inserted and the module printed.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    ir_text = r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
}
"""
    with Location.unknown(ctx):
        module = Module.parse(ir_text)
        func_op = module.body.operations[0]
        body_block = func_op.regions[0].blocks[0]
        point = InsertionPoint(body_block)
        assert point.block == body_block
        # No anchor operation is expected for a block-end insertion point.
        assert point.ref_operation is None
        point.insert(Operation.create("custom.op2"))
        module.operation.print()


run(test_insert_at_block_end)
def test_insert_before_operation():
    """Insert through an insertion point anchored on an existing operation.

    The point is built from the second operation of the function body
    (``custom.op2``) and must report the same block and that operation as its
    reference; ``custom.op3`` is then inserted and the module printed.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    ir_text = r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
"custom.op2"() : () -> ()
}
"""
    with Location.unknown(ctx):
        module = Module.parse(ir_text)
        body_block = module.body.operations[0].regions[0].blocks[0]
        anchor = body_block.operations[1]
        point = InsertionPoint(anchor)
        assert point.block == body_block
        assert point.ref_operation == anchor
        point.insert(Operation.create("custom.op3"))
        module.operation.print()


run(test_insert_before_operation)
def test_insert_at_block_begin():
    """Insert at the beginning of a non-empty block.

    ``InsertionPoint.at_block_begin`` must report the block and its first
    operation (``custom.op2``) as the reference; ``custom.op1`` is then
    inserted and the module printed.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    ir_text = r"""
func.func @foo() -> () {
"custom.op2"() : () -> ()
}
"""
    with Location.unknown(ctx):
        module = Module.parse(ir_text)
        func_op = module.body.operations[0]
        body_block = func_op.regions[0].blocks[0]
        point = InsertionPoint.at_block_begin(body_block)
        first_op = body_block.operations[0]
        assert point.block == body_block
        assert point.ref_operation == first_op
        point.insert(Operation.create("custom.op1"))
        module.operation.print()


run(test_insert_at_block_begin)
def test_insert_at_block_begin_empty():
    """Insert at the beginning of a block that contains no operations.

    The body block of a module from ``Module.create()`` is used as the empty
    block (assumed empty on creation — confirm against the bindings' docs).
    With no first operation to anchor on, ``InsertionPoint.at_block_begin``
    is expected to report the block and no reference operation. Inserting
    ``custom.op1`` must still succeed, and the module is printed.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.create()
        body_block = module.body
        ip = InsertionPoint.at_block_begin(body_block)
        assert ip.block == body_block
        # Nothing to anchor on, so the point should carry no reference op.
        assert ip.ref_operation is None
        ip.insert(Operation.create("custom.op1"))
        module.operation.print()


run(test_insert_at_block_begin_empty)
def test_insert_at_terminator():
    """Insert at a block's terminator.

    ``InsertionPoint.at_block_terminator`` on a body ending in ``return``
    must report the block and that ``return`` (operation index 1) as its
    reference; ``custom.op2`` is then inserted and the module printed.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    ir_text = r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
return
}
"""
    with Location.unknown(ctx):
        module = Module.parse(ir_text)
        body_block = module.body.operations[0].regions[0].blocks[0]
        point = InsertionPoint.at_block_terminator(body_block)
        terminator = body_block.operations[1]
        assert point.block == body_block
        assert point.ref_operation == terminator
        point.insert(Operation.create("custom.op2"))
        module.operation.print()


run(test_insert_at_terminator)
def test_insert_at_block_terminator_missing():
    """``at_block_terminator`` must reject a block without a terminator.

    The function body holds only the unregistered ``custom.op1``, so
    ``InsertionPoint.at_block_terminator`` is expected to raise
    ``ValueError``; its message is printed. Any other outcome fails the test.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        module = Module.parse(
            r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
}
"""
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        try:
            InsertionPoint.at_block_terminator(entry_block)
        except ValueError as e:
            print(e)
        else:
            # Explicit raise rather than `assert False`: assert statements are
            # removed under `python -O`, which would let this case pass.
            raise AssertionError("Expected exception")


run(test_insert_at_block_terminator_missing)
def test_insert_at_end_with_terminator_errors():
    """Appending after an existing terminator must raise ``IndexError``.

    The function body already ends in ``return``. Creating ``custom.op1``
    while ``InsertionPoint(entry_block)`` is active is therefore expected to
    raise ``IndexError``, whose message is printed with an ``ERROR: `` prefix.
    Any other outcome fails the test.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @foo() -> () {
return
}
"""
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            try:
                Operation.create("custom.op1", results=[], operands=[])
            except IndexError as e:
                print(f"ERROR: {e}")
            else:
                # Without this branch a regression where the insertion
                # silently succeeds would go unnoticed here.
                raise AssertionError("Expected IndexError")


run(test_insert_at_end_with_terminator_errors)
def test_insertion_point_context():
    """Exercise ``InsertionPoint`` as a nested context manager.

    The operations below are never inserted explicitly, so their placement
    comes from whichever insertion point is active:
    - ``custom.op2`` is created under ``InsertionPoint(entry_block)``.
    - ``custom.opa`` and ``custom.opb`` are created under a nested
      ``InsertionPoint.at_block_begin(entry_block)``.
    - ``custom.op3`` is created after that nested point exits.

    The resulting module is printed.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    ir_text = r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
}
"""
    with Location.unknown(ctx):
        module = Module.parse(ir_text)
        func_op = module.body.operations[0]
        body_block = func_op.regions[0].blocks[0]
        with InsertionPoint(body_block):
            Operation.create("custom.op2")
            with InsertionPoint.at_block_begin(body_block):
                for op_name in ("custom.opa", "custom.opb"):
                    Operation.create(op_name)
            Operation.create("custom.op3")
    module.operation.print()


run(test_insertion_point_context)