import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *
def run(f):
    """Decorator: execute the test function ``f`` now and hand it back.

    Before calling ``f`` it prints a blank line followed by ``TEST: <name>``
    so each test's output is labelled. ``f`` is returned unchanged.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
@run
def testSubViewAccessors():
    """Check the generated operand accessors of a parsed ``memref.subview``.

    The parsed function holds six ``arith.constant`` ops followed by one
    subview whose offsets, sizes and strides are all SSA values. The test
    cross-checks ``source``/``result`` against the raw operand/result lists,
    asserts each accessor group has two entries, then prints the op class
    name and every accessor value.
    """
    context = Context()
    module = Module.parse(
        r"""
func.func @f1(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%2 = arith.constant 2 : index
%3 = arith.constant 3 : index
%4 = arith.constant 4 : index
%5 = arith.constant 5 : index
memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
return
}
""",
        context,
    )
    entry_block = module.body.operations[0].regions[0].blocks[0]
    # The six constants occupy indices 0..5, so the subview is at index 6.
    op = entry_block.operations[6]

    assert op.source == op.operands[0]
    for operand_group in (op.offsets, op.sizes, op.strides):
        assert len(operand_group) == 2
    assert op.result == op.results[0]

    print(type(op).__name__)
    # Same print order as offsets[0], offsets[1], sizes[0], ... strides[1].
    for operand_group in (op.offsets, op.sizes, op.strides):
        print(operand_group[0])
        print(operand_group[1])
@run
def testCustomBuidlers():
    """Build a ``memref.load`` through the ``memref.LoadOp`` Python builder.

    The load is inserted just before the terminator of a parsed function.
    It reads the first block argument (the memref) at the indices given by
    the remaining block arguments. The module is then printed and verified.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.parse(
            r"""
func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
return
}
"""
        )
        func_op = module.body.operations[0]
        entry_block = func_op.regions[0].blocks[0]
        args = func_op.arguments
        with InsertionPoint.at_block_terminator(entry_block):
            memref.LoadOp(args[0], args[1:])
        print(module)
        assert module.operation.verify()
@run
def testMemRefAttr():
    """Emit a ``memref.global`` named ``objFifo_in0`` and print the module.

    The global's type is built with ``T.memref(16, T.i32())``.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        # Types belong to the active context, not the insertion point, so
        # the type can be built before entering the module body.
        global_type = T.memref(16, T.i32())
        with InsertionPoint(module.body):
            memref.global_("objFifo_in0", global_type)
        print(module)
@run
def testSubViewOpInferReturnTypeSemantics():
    """Exercise result-type inference in the ``memref.subview`` builder.

    Covers, in order:
      * all-static offsets/sizes/strides;
      * offsets supplied (partly or fully) as ``arith.constant`` values;
      * an offset produced by ``arith.addi``;
      * two requests whose size list contains ``ShapedType.get_dynamic_size()``,
        which are expected to raise ``ValueError``;
      * a source with a dynamic-offset strided layout, where the result type
        is passed explicitly via ``result_type``.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            print(x.owner)

            # All-static offsets/sizes/strides.
            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
            assert y.owner.verify()
            print(y.owner)

            # One offset as an SSA constant, the other as a Python int.
            z = memref.subview(
                x,
                [arith.constant(T.index(), 1), 1],
                [3, 3],
                [1, 1],
            )
            print(z.owner)

            # Both offsets as SSA constants.
            z = memref.subview(
                x,
                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                [3, 3],
                [1, 1],
            )
            print(z.owner)

            # Offset computed by arith.addi rather than taken from a constant op.
            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
            z = memref.subview(
                x,
                [s, 0],
                [3, 3],
                [1, 1],
            )
            # NOTE: prints the Value itself, unlike the `.owner` prints above.
            print(z)

            # A dynamic size is expected to be rejected with ValueError. The
            # `else` branches make the test fail loudly instead of passing
            # silently if the call is accepted.
            try:
                _infer_memref_subview_result_type(
                    x.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                print(e)
            else:
                raise AssertionError(
                    "expected ValueError from _infer_memref_subview_result_type"
                )

            try:
                memref.subview(
                    x,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                print(e)
            else:
                raise AssertionError("expected ValueError from memref.subview")

            # Source whose strided layout has a dynamic offset. The trailing
            # operand list of alloc passes the constant 42, presumably bound
            # to that dynamic offset (TODO confirm). The subview's result
            # type is given explicitly instead of being inferred.
            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
            x = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=layout,
                ),
                [],
                [arith.constant(T.index(), 42)],
            )
            print(x.owner)
            y = memref.subview(
                x,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            print(y.owner)
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Compare inferred subview strides/offsets against equivalent NumPy views.

    Each ``memref.subview`` is paired with a slice of a zero-filled int32
    NumPy array of the same shape. The element strides and element offset of
    that NumPy view are the expected strides/offset of the subview's strided
    layout. Only strides and offset are compared, so a NumPy slice's extent
    does not have to equal the subview's ``sizes``.
    """
    # `np` was used here but never imported anywhere in the file, which
    # raised NameError when this test ran.
    import numpy as np

    def check_strides_offset(subview_result, np_view):
        """Assert ``subview_result``'s layout strides/offset match ``np_view``.

        NumPy reports strides and data addresses in bytes, so both are
        divided by the item size to get element counts.
        """
        layout = subview_result.type.layout
        itemsize = np_view.dtype.itemsize
        golden_strides = (np.array(np_view.strides) // itemsize).tolist()
        golden_offset = (np_view.ctypes.data - np_view.base.ctypes.data) // itemsize
        actual = (layout.strides, layout.offset)
        expected = (golden_strides, golden_offset)
        assert actual == expected, f"subview layout {actual} != numpy {expected}"

    def check_cases(mem, golden_mem, cases):
        """Build each ``(offsets, sizes, strides)`` subview of ``mem`` in order
        and check it against ``golden_mem[golden_index]``."""
        for offsets, sizes, strides, golden_index in cases:
            check_strides_offset(
                memref.subview(mem, offsets, sizes, strides),
                golden_mem[golden_index],
            )

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            # 4-D buffer, unit strides: size-1 slices at index 1 along every
            # non-empty, non-full combination of dimensions.
            shape = (10, 22, 3, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            unit = (1, 1, 1, 1)
            check_cases(
                mem1,
                golden_mem,
                [
                    ((1, 0, 0, 0), (1, 22, 3, 44), unit, np.s_[1:2, ...]),
                    ((0, 1, 0, 0), (10, 1, 3, 44), unit, np.s_[:, 1:2]),
                    ((0, 0, 1, 0), (10, 22, 1, 44), unit, np.s_[:, :, 1:2]),
                    ((0, 0, 0, 1), (10, 22, 3, 1), unit, np.s_[:, :, :, 1:2]),
                    ((0, 1, 0, 1), (10, 1, 3, 1), unit, np.s_[:, 1:2, :, 1:2]),
                    ((1, 0, 0, 1), (1, 22, 3, 1), unit, np.s_[1:2, :, :, 1:2]),
                    ((1, 1, 0, 0), (1, 1, 3, 44), unit, np.s_[1:2, 1:2, :, :]),
                    ((0, 0, 1, 1), (10, 22, 1, 1), unit, np.s_[:, :, 1:2, 1:2]),
                    ((0, 1, 1, 0), (10, 1, 1, 44), unit, np.s_[:, 1:2, 1:2, :]),
                    ((1, 0, 1, 0), (1, 22, 1, 44), unit, np.s_[1:2, :, 1:2, :]),
                    ((1, 1, 0, 1), (1, 1, 3, 1), unit, np.s_[1:2, 1:2, :, 1:2]),
                    ((1, 0, 1, 1), (1, 22, 1, 1), unit, np.s_[1:2, :, 1:2, 1:2]),
                    ((0, 1, 1, 1), (10, 1, 1, 1), unit, np.s_[:, 1:2, 1:2, 1:2]),
                    ((1, 1, 1, 0), (1, 1, 1, 44), unit, np.s_[1:2, 1:2, 1:2, :]),
                ],
            )

            # A subview covering the whole buffer with zero offsets and unit
            # strides gets the identity affine-map layout.
            assert memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
            ).type.layout == AffineMapAttr.get(
                AffineMap.get(
                    4,
                    0,
                    [
                        AffineDimExpr.get(0),
                        AffineDimExpr.get(1),
                        AffineDimExpr.get(2),
                        AffineDimExpr.get(3),
                    ],
                )
            )

            # Non-unit strides, mirrored by NumPy step slices.
            shape = (7, 22, 30, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_cases(
                mem2,
                golden_mem,
                [
                    (
                        (0, 0, 0, 0),
                        (7, 11, 3, 44),
                        (1, 2, 1, 1),
                        np.s_[:, 0:22:2],
                    ),
                    (
                        (0, 0, 0, 0),
                        (7, 11, 11, 44),
                        (1, 2, 30, 1),
                        np.s_[:, 0:22:2, 0:330:30],
                    ),
                    (
                        (0, 0, 0, 0),
                        (7, 11, 11, 11),
                        (1, 2, 30, 400),
                        np.s_[:, 0:22:2, 0:330:30, 0:4400:400],
                    ),
                ],
            )

            # 2-D buffer: zero and non-zero offsets.
            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_cases(
                mem3,
                golden_mem,
                [
                    ((0, 0), (4, 4), (1, 1), np.s_[0:4, 0:4]),
                    ((4, 4), (4, 4), (1, 1), np.s_[4:8, 4:8]),
                ],
            )