import gc
import io
import itertools
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.dialects._ods_common import _cext
def run(f):
    """Decorator that executes test function `f` immediately.

    Prints a "TEST:" banner with the function's name, calls it, forces a
    garbage collection and then asserts that `Context._get_live_count()`
    reports zero (presumably: no MLIR contexts leaked by the test).
    Returns `f` itself, so the decorated name still refers to the function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    gc.collect()
    remaining_contexts = Context._get_live_count()
    assert remaining_contexts == 0
    return f
def expect_index_error(callback):
    """Invoke `callback()` once and require that it raises IndexError.

    An IndexError is swallowed and the function returns None. If `callback`
    returns normally, RuntimeError("Expected IndexError") is raised. Any other
    exception raised by `callback` propagates unchanged.
    """
    try:
        callback()
    except IndexError:
        return
    raise RuntimeError("Expected IndexError")
@run
def testTraverseOpRegionBlockIterators():
    """Traverse a parsed module through the region/block/operation iterators.

    Checks that the named collections (`.regions`, `.blocks`, `.operations`)
    agree with default iteration over regions and blocks. It then prints a
    recursive dump of the IR and the iterator objects themselves.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )
    top = module.operation
    assert top.context is ctx

    # Materialize everything through the explicitly named collections.
    region_list = list(top.regions)
    block_list = list(region_list[0].blocks)
    print(f"MODULE REGIONS={len(region_list)} BLOCKS={len(block_list)}")
    print(f".verify = {module.operation.verify()}")

    # Iterating a region directly must match iterating `region.blocks`...
    assert list(region_list[0]) == block_list
    # ...and iterating a block directly must match `block.operations`.
    named_ops = list(block_list[0].operations)
    assert list(block_list[0]) == named_ops

    def dump(prefix, parent):
        # Recursively print every region, block and nested operation.
        for region_idx, region in enumerate(parent.regions):
            print(f"{prefix}REGION {region_idx}:")
            for block_idx, block in enumerate(region):
                print(f"{prefix} BLOCK {block_idx}:")
                for op_idx, nested in enumerate(block):
                    print(f"{prefix} OP {op_idx}: {nested}")
                    dump(prefix + " ", nested)

    dump("", top)
    print(" Region iter:", iter(top.regions))
    print(" Block iter:", iter(top.regions[0]))
    print("Operation iter:", iter(top.regions[0].blocks[0]))
@run
def testTraverseOpRegionBlockIndices():
    """Traverse a parsed module using only `len()` and integer indexing.

    Mirrors the iterator-based traversal test, but exercises `__len__` and
    `__getitem__` on the region, block and operation lists. It also prints the
    name of each operation's parent.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )

    def dump(prefix, parent):
        # Index-based on purpose: this test is about __len__/__getitem__.
        region_count = len(parent.regions)
        for r in range(region_count):
            region = parent.regions[r]
            print(f"{prefix}REGION {r}:")
            block_count = len(region.blocks)
            for b in range(block_count):
                block = region.blocks[b]
                print(f"{prefix} BLOCK {b}:")
                op_count = len(block.operations)
                for o in range(op_count):
                    nested = block.operations[o]
                    print(f"{prefix} OP {o}: {nested}")
                    print(f"{prefix} OP {o}: parent {nested.operation.parent.name}")
                    dump(prefix + " ", nested)

    dump("", module.operation)
@run
def testBlockAndRegionOwners():
    """Check that regions and blocks report their owning operation via `.owner`.

    Checked both for the top-level module op and for the nested func op.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
builtin.module {
func.func @f() {
func.return
}
}
""",
        ctx,
    )

    def assert_owned_by(expected_owner, holder):
        # `holder` supplies the regions; the first region and its first block
        # must both compare equal to `expected_owner`.
        first_region = holder.regions[0]
        assert first_region.owner == expected_owner
        assert first_region.blocks[0].owner == expected_owner

    assert_owned_by(module.operation, module.operation)
    func = module.body.operations[0]
    assert_owned_by(func, func.operation)
@run
def testBlockArgumentList():
    """Exercise a block's argument list.

    Covers iteration, in-place retyping, slicing, `.types`, and concatenating
    slices.
    """
    with Context() as ctx:
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
return
}
""",
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        assert len(entry_block.arguments) == 3

        def show(arguments):
            # Print number and type of every argument in `arguments`.
            for a in arguments:
                print(f"Argument {a.arg_number}, type {a.type}")

        # Print the parsed types, then retype argument N to i(8 * (N + 1)).
        for a in entry_block.arguments:
            print(f"Argument {a.arg_number}, type {a.type}")
            a.set_type(IntegerType.get_signless(8 * (a.arg_number + 1)))

        show(entry_block.arguments)
        show(entry_block.arguments[1:])

        concatenated = entry_block.arguments[:2] + entry_block.arguments[1:]
        print("Length: ", len(concatenated))

        for t in entry_block.arguments.types:
            print("Type: ", t)
        for t in entry_block.arguments[1:].types:
            print("Sliced type: ", t)

        # Concatenation of a tail slice and a head slice.
        show(entry_block.arguments[-1:] + entry_block.arguments[:1])
@run
def testOperationOperands():
    """Check the operand count and types of an op with two operands.

    One operand is a block argument; the other is another op's result.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32) {
%0 = "test.producer"() : () -> i64
"test.consumer"(%arg0, %0) : (i32, i64) -> ()
return
}"""
        )
        body = module.body.operations[0].regions[0].blocks[0]
        consumer = body.operations[1]
        assert len(consumer.operands) == 2
        for idx, value in enumerate(consumer.operands):
            print(f"Operand {idx}, type {value.type}")
@run
def testOperationOperandsSlice():
    """Slice an operand list in several ways and print each slice's values."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
%3 = "test.producer3"() : () -> i64
%4 = "test.producer4"() : () -> i64
"test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
return
}"""
        )
        body = module.body.operations[0].regions[0].blocks[0]
        consumer = body.operations[5]
        assert len(consumer.operands) == 5

        # Reversing twice must give back the original order.
        for lhs, rhs in zip(consumer.operands, consumer.operands[::-1][::-1]):
            assert lhs == rhs

        # x[slice(a, b, c)] is exactly what x[a:b:c] dispatches to.
        selectors = (slice(None), slice(0, 2), slice(3, None), slice(None, None, 2))
        for selector in selectors:
            for operand in consumer.operands[selector]:
                print(operand)

        # A slice of a slice: every other element of the even positions.
        for operand in consumer.operands[::2][1::2]:
            print(operand)
@run
def testOperationOperandsSet():
    """Replace an operand in place via a positive and a negative index.

    The consumer starts with a single i64 operand (%0). It is rewired first to
    producer1's result through index 0, then to producer2's result through
    index -1. The operand is printed after each assignment.
    """
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
"test.consumer"(%0) : (i64) -> ()
return
}"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # Sanity-check the initial operand's type (the IR declares it i64)
        # instead of reading it into an unused local that shadowed `type`.
        assert str(consumer.operands[0].type) == "i64"

        consumer.operands[0] = producer1.result
        print(consumer.operands[0])

        # Negative indices are accepted for assignment as well.
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
@run
def testDetachedOperation():
    """Create and print an op that has no parent.

    The op has two results, one region and two string attributes.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        si32 = IntegerType.get_signed(32)
        attrs = {
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        }
        detached = Operation.create(
            "custom.op1", results=[si32, si32], regions=1, attributes=attrs
        )
        print(detached)
@run
def testOperationInsertionPoint():
    """Insert two detached ops at the start of a block.

    Afterwards, inserting an already-attached op must raise ValueError.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )
    with Location.unknown(ctx):
        first = Operation.create("custom.op1")
        second = Operation.create("custom.op2")
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        for detached in (first, second):
            ip.insert(detached)
        print(module)

    # `first` now has a parent, so inserting it again must be rejected.
    try:
        ip.insert(first)
    except ValueError:
        pass
    else:
        assert False, "expected insert of attached op to raise"
@run
def testOperationWithRegion():
    """Build an op with a populated region, then insert it into a parsed func."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        si32 = IntegerType.get_signed(32)
        outer = Operation.create("custom.op1", regions=1)
        # Give the region a block with two si32 arguments and a terminator.
        inner_block = outer.regions[0].blocks.append(si32, si32)
        terminator = Operation.create("custom.terminator")
        InsertionPoint(inner_block).insert(terminator)
        print(outer)

        # Move the whole op (with its region) into the entry block of @f1.
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
"""
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        InsertionPoint.at_block_begin(entry_block).insert(outer)
        print(module)
@run
def testOperationResultList():
    """Inspect the result lists of a three-result call and a zero-result call."""
    ctx = Context()
    module = Module.parse(
        r"""
func.func @f1() {
%0:3 = call @f2() : () -> (i32, f64, index)
call @f3() : () -> ()
return
}
func.func private @f2() -> (i32, f64, index)
func.func private @f3() -> ()
""",
        ctx,
    )
    caller = module.body.operations[0]
    body_ops = caller.regions[0].blocks[0].operations

    call = body_ops[0]
    assert len(call.results) == 3
    for r in call.results:
        print(f"Result {r.result_number}, type {r.type}")
    for result_type in call.results.types:
        print(f"Result type {result_type}")

    # One past either end of the three results must be out of range.
    for bad_index in (3, -4):
        expect_index_error(lambda i=bad_index: call.results[i])

    no_results_call = body_ops[1]
    assert len(no_results_call.results) == 0
    assert no_results_call.results.owner == no_results_call
@run
def testOperationResultListSlice():
    """Slice a five-element result list and print each slice's results."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
"some.op"() : () -> (i1, i2, i3, i4, i5)
return
}
"""
        )
        body = module.body.operations[0].regions[0].blocks[0]
        producer = body.operations[0]
        assert len(producer.results) == 5

        # Reversing twice must give back equal results in the same positions.
        for lhs, rhs in zip(producer.results, producer.results[::-1][::-1]):
            assert lhs == rhs
            assert lhs.result_number == rhs.result_number

        # Full, middle [1:4], odd [1::2] and reversed-stepped [-2:0:-2] slices.
        selectors = (
            slice(None),
            slice(1, 4),
            slice(1, None, 2),
            slice(-2, 0, -2),
        )
        for selector in selectors:
            for res in producer.results[selector]:
                print(f"Result {res.result_number}, type {res.type}")
@run
def testOperationAttributes():
    """Read, iterate and mis-index the attribute map of an op."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
""",
        ctx,
    )
    op = module.body.operations[0]
    attributes = op.attributes
    assert len(attributes) == 3

    int_attr = attributes["some.attribute"]
    float_attr = attributes["other.attribute"]
    str_attr = attributes["dependent"]
    print(f"Attribute type {int_attr.type}, value {int_attr.value}")
    print(f"Attribute type {float_attr.type}, value {float_attr.value}")
    print(f"Attribute value {str_attr.value}")
    print(f"Attribute value {str_attr.value_bytes}")

    for named in attributes:
        print(str(named))

    # An unknown name must raise KeyError; an out-of-range index, IndexError.
    bad_lookups = (
        (
            "does_not_exist",
            KeyError,
            "expected KeyError on accessing a non-existent attribute",
        ),
        (
            42,
            IndexError,
            "expected IndexError on accessing an out-of-bounds attribute",
        ),
    )
    for key, expected_error, failure_message in bad_lookups:
        try:
            attributes[key]
        except expected_error:
            pass
        else:
            assert False, failure_message
@run
def testOperationPrint():
    """Exercise `Operation.print` against stdout, text streams and binary streams.

    Also round-trips the module through bytecode and checks that the re-parsed
    text matches. Finally it prints with an explicit AsmState and with a
    combination of printing flags.
    """
    ctx = Context()
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
return %arg0 : i32
}
""",
        ctx,
    )
    top = module.operation

    def as_text(operation):
        # Render `operation` into an in-memory text stream and return it.
        sink = io.StringIO()
        operation.print(file=sink)
        return sink.getvalue()

    # Plain print to stdout.
    top.print()

    # Print into a text stream.
    text = as_text(top)
    print(text.__class__)
    print(text)

    # Bytecode round-trip must reproduce the same textual form.
    blob_stream = io.BytesIO()
    top.write_bytecode(blob_stream, desired_version=1)
    blob = blob_stream.getvalue()
    assert blob.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
    reparsed = Module.parse(blob, ctx)
    assert text == as_text(reparsed.operation), "Mismatch after roundtrip bytecode"

    # Print into a binary stream.
    raw = io.BytesIO()
    top.print(file=raw, binary=True)
    raw_value = raw.getvalue()
    print(raw_value.__class__)
    print(raw_value)

    top.print(enable_debug_info=True, use_local_scope=True)
    top.print(AsmState(top))
    top.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True,
    )
    module.body.operations[0].print(skip_regions=True)
@run
def testKnownOpView():
    """Check which OpView class parsed operations are wrapped in.

    Prints the repr of registered ops (arith.addf, arith.constant) and of an
    unregistered custom op. Then registers a replacement ConstantOp class with
    `replace=True` and fetches the arith.constant op again.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
%1 = "custom.f32"() : () -> f32
%2 = "custom.f32"() : () -> f32
%3 = arith.addf %1, %2 : f32
%4 = arith.constant 0 : i32
"""
        )
        print(module)
        # arith.addf is registered; its view is expected to expose `lhs`.
        addf = module.body.operations[2]
        print(repr(addf))
        print(addf.lhs)
        # Unregistered op: presumably resolves to the default OpView.
        custom = module.body.operations[0]
        print(repr(custom))
        # Deliberately fetched a second time. NOTE(review): presumably this
        # checks that a cached "no specialized OpView" lookup is still correct.
        custom = module.body.operations[0]
        print(repr(custom))
        constant = module.body.operations[3]
        print(repr(constant))
        print("literal value", constant.literal_value)

        # Register a subclass for arith.constant, replacing the existing
        # registration (replace=True).
        # NOTE(review): the registry is presumably process-global, so this
        # likely stays in effect for tests that run later; confirm.
        @_cext.register_operation(arith._Dialect, replace=True)
        class ConstantOp(arith.ConstantOp):
            def __init__(self, result, value, *, loc=None, ip=None):
                # Convenience constructor: a Python int/float is wrapped into
                # an IntegerAttr/FloatAttr of type `result`; any other value is
                # forwarded unchanged to the base constructor.
                if isinstance(value, int):
                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
                elif isinstance(value, float):
                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
                else:
                    super().__init__(value, loc=loc, ip=ip)

        # Re-fetch the arith.constant op after the replacement registration.
        constant = module.body.operations[3]
        print(repr(constant))
@run
def testSingleResultProperty():
    """Check that `.result` is only usable on an op with exactly one result.

    Accessing it on the zero-result and two-result ops must raise ValueError,
    whose message is printed. The one-result op is then printed.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
"custom.no_result"() : () -> ()
%0:2 = "custom.two_result"() : () -> (f32, f32)
%1 = "custom.one_result"() : () -> f32
"""
        )
        print(module)

        # Ops 0 and 1 have 0 and 2 results respectively.
        for index in (0, 1):
            try:
                module.body.operations[index].result
            except ValueError as err:
                print(err)
            else:
                assert False, "Expected exception"

        print(module.body.operations[2])
def create_invalid_operation():
    """Create a `builtin.module` op meant to be invalid, and return it.

    The op is created with two regions, and only the first receives a block.
    Creation uses the ambient Location, and InsertionPoint if one is active,
    so the caller controls where (or whether) the op is inserted.
    """
    invalid = Operation.create("builtin.module", regions=2)
    first_region = invalid.regions[0]
    first_region.blocks.append()
    return invalid
@run
def testInvalidOperationStrSoftFails():
    """Print an invalid op, then print the MLIRError raised by `verify()`."""
    ctx = Context()
    with Location.unknown(ctx):
        bad_op = create_invalid_operation()
        print(bad_op)
        try:
            bad_op.verify()
        except MLIRError as err:
            print(f"Exception: <{err}>")
@run
def testInvalidModuleStrSoftFails():
    """Print a module whose body contains an invalid op."""
    ctx = Context()
    with Location.unknown(ctx):
        module = Module.create()
        # The active insertion point places the invalid op in the module body.
        with InsertionPoint(module.body):
            bad_op = create_invalid_operation()
        print(module)
@run
def testInvalidOperationGetAsmBinarySoftFails():
    """Print the result of `get_asm(binary=True)` on an invalid op."""
    ctx = Context()
    with Location.unknown(ctx):
        bad_op = create_invalid_operation()
        asm = bad_op.get_asm(binary=True)
        print(asm)
@run
def testCreateWithInvalidAttributes():
    """Call Operation.create with malformed attribute dicts.

    The dicts use non-string keys (None, 42) or non-attribute values (a Context,
    None). The message of any exception raised is printed.
    """
    ctx = Context()
    with Location.unknown(ctx):
        # Thunks, so each dict (and any StringAttr in it) is built inside the
        # try block, exactly where the create call happens.
        attribute_factories = (
            lambda: {None: StringAttr.get("name")},
            lambda: {42: StringAttr.get("name")},
            lambda: {"some_key": ctx},
            lambda: {"some_key": None},
        )
        for make_attributes in attribute_factories:
            try:
                Operation.create("builtin.module", attributes=make_attributes())
            except Exception as err:
                print(err)
@run
def testOperationName():
    """Print the `.name` of every top-level op in a parsed module."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
%0 = "custom.op1"() : () -> f32
%1 = "custom.op2"() : () -> i32
%2 = "custom.op1"() : () -> f32
""",
        ctx,
    )
    top_level_ops = module.body.operations
    for view in top_level_ops:
        op_name = view.operation.name
        print(op_name)
@run
def testCapsuleConversions():
    """Round-trip an Operation through its C-API capsule."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        original = Operation.create("custom.op1").operation
        capsule = original._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
        # Re-wrapping the capsule must yield the very same Python object.
        rewrapped = Operation._CAPICreate(capsule)
        assert rewrapped is original
@run
def testOperationErase():
    """Erase an op from a module, printing the module before and after."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            victim = Operation.create("custom.op1")
            print(m)
            victim.operation.erase()
            print(m)
            # Create another op afterwards. Presumably this checks that the
            # insertion point is still usable after the erase.
            Operation.create("custom.op2")
@run
def testOperationClone():
    """Clone an op, erase the original, and print the module before and after."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            source_op = Operation.create("custom.op1")
            print(m)
            # The copy is kept referenced until the end of the test.
            # NOTE(review): presumably it lands at the active insertion point,
            # so a "custom.op1" should still be printed after the erase.
            duplicate = source_op.operation.clone()
            source_op.operation.erase()
            print(m)
@run
def testOperationLoc():
    """Check that the location passed at creation is reported back.

    It must be visible both via the created object and via its `.operation`.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        named_loc = Location.name("loc")
        created = Operation.create("custom.op", loc=named_loc)
        for view in (created, created.operation):
            assert view.location == named_loc
@run
def testModuleMerge():
    """Move two funcs from one module into another using move_before/move_after."""
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse(
            """
func.func private @bar()
func.func private @qux()
"""
        )
        foo = m1.body.operations[0]
        bar, qux = m2.body.operations[0], m2.body.operations[1]

        # Expected order in m1 afterwards: bar, foo, qux.
        bar.move_before(foo)
        qux.move_after(foo)

        print(m1)
        print(m2)
@run
def testAppendMoveFromAnotherBlock():
    """Append an op owned by one module to another module's body.

    Both modules are printed afterwards.
    """
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse("func.func private @bar()")
        moved = m1.body.operations[0]
        m2.body.append(moved)
        print(m2)
        print(m1)
@run
def testDetachFromParent():
    """Detach an op from its module; detaching it again must raise ValueError."""
    with Context():
        m1 = Module.parse("func.func private @foo()")
        detached = m1.body.operations[0].detach_from_parent()

        try:
            detached.detach_from_parent()
        except ValueError as err:
            # Only the "has no parent" ValueError is expected; re-raise others.
            if "has no parent" not in str(err):
                raise
        else:
            assert False, "expected ValueError when detaching a detached operation"

        print(m1)
@run
def testOperationHash():
    """Check that the object from Operation.create hashes like its `.operation`."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx, Location.unknown():
        created = Operation.create("custom.op1")
        assert hash(created) == hash(created.operation)
@run
def testOperationParse():
    """Check the types returned by Operation.parse and ModuleOp.parse.

    Also checks that ModuleOp.parse rejects a non-module op and that a
    source_name passed to Operation.parse appears in the debug-info asm.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        parsed_module = Operation.parse("module {}")
        parsed_generic = Operation.parse('"test.foo"() : () -> ()')
        assert isinstance(parsed_module, ModuleOp)
        # Exact type check on purpose: the unregistered op gets plain OpView.
        assert type(parsed_generic) is OpView

        # The class-specific parser accepts a module...
        parsed_module = ModuleOp.parse("module {}")
        assert isinstance(parsed_module, ModuleOp)

        # ...but rejects a non-module op.
        try:
            ModuleOp.parse('"test.foo"() : () -> ()')
        except MLIRError as err:
            print(f"error: {err}")
        else:
            assert False, "expected error"

        named = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
        asm = named.get_asm(enable_debug_info=True, use_local_scope=True)
        print(f"op_with_source_name: {asm}")
@run
def testOpWalk():
    """Exercise Operation.walk with each WalkResult and both walk orders.

    Covers:
    - post-order (default) and pre-order traversal;
    - early termination via INTERRUPT;
    - subtree skipping via SKIP;
    - propagation of an exception raised inside the callback, which surfaces
      to the caller as RuntimeError.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
builtin.module {
func.func @f() {
func.return
}
}
""",
        ctx,
    )

    def print_and_advance(op):
        print(op.name)
        return WalkResult.ADVANCE

    print("Post-order")
    module.operation.walk(print_and_advance)
    print("Pre-order")
    module.operation.walk(print_and_advance, WalkOrder.PRE_ORDER)

    def print_and_interrupt(op):
        # Stop the whole walk after the first visited op.
        print(op.name)
        return WalkResult.INTERRUPT

    print("Interrupt post-order")
    module.operation.walk(print_and_interrupt)

    def print_and_skip(op):
        # Do not descend into the visited op's regions.
        print(op.name)
        return WalkResult.SKIP

    print("Skip pre-order")
    module.operation.walk(print_and_skip, WalkOrder.PRE_ORDER)

    def print_and_raise(op):
        print(op.name)
        raise ValueError

    print("Exception")
    try:
        module.operation.walk(print_and_raise)
    except RuntimeError:
        print("Exception raised")
    else:
        # Same guard pattern as the other expected-exception checks in this file.
        assert False, "expected walk to raise when the callback raises"