import gc
from mlir.ir import *
def run(f):
    """Run test function ``f`` right away, then hand it back unchanged.

    Meant to be used as a decorator. It prints a ``TEST:`` banner with the
    function's name, calls ``f``, forces a garbage-collection pass, and
    asserts that ``Context._get_live_count()`` reports zero live contexts
    afterwards.
    """
    banner = "\nTEST:"
    print(banner, f.__name__)
    f()
    # Collect first so that contexts kept alive only by reference cycles
    # do not count as live.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
@run
def testAffineMapCapsule():
    """Round-trip an AffineMap through its C-API capsule.

    The capsule from ``_CAPIPtr`` is printed and then rebuilt into an
    ``AffineMap`` with ``_CAPICreate``. The rebuilt map must equal the
    original and report the same owning context.
    """
    with Context() as context:
        original = AffineMap.get_empty(context)
        capsule = original._CAPIPtr
        print(capsule)

        rebuilt = AffineMap._CAPICreate(capsule)
        assert rebuilt == original
        assert rebuilt.context is context
@run
def testAffineMapGet():
    """Build AffineMaps with each factory and check that bad inputs are rejected.

    Prints every constructed map. Asserts that the convenience factories
    (``get_empty``, ``get_constant``, ``get_identity``,
    ``get_minor_identity``) agree with the equivalent explicit ``get`` calls.
    Finally prints the error raised by each malformed request, and fails if
    a malformed request is accepted.
    """
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)

        map0 = AffineMap.get(2, 3, [])
        print(map0)
        map1 = AffineMap.get(2, 3, [d1, c2])
        print(map1)
        map2 = AffineMap.get(0, 0, [c2])
        print(map2)
        map3 = AffineMap.get(2, 0, [d0, d1])
        print(map3)
        map4 = AffineMap.get(2, 0, [d1])
        print(map4)
        map5 = AffineMap.get_permutation([2, 0, 1])
        print(map5)

        assert map1 == AffineMap.get(2, 3, [d1, c2])
        assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
        assert map2 == AffineMap.get_constant(2)
        assert map3 == AffineMap.get_identity(2)
        assert map4 == AffineMap.get_minor_identity(2, 1)

        # Each malformed request must raise the paired exception type.
        # The message is printed so the output stays checkable. The ``else``
        # branch makes the test fail instead of silently passing when
        # nothing is raised.
        invalid_requests = [
            # An int is not an AffineExpr.
            (RuntimeError, lambda: AffineMap.get(1, 1, [1])),
            (RuntimeError, lambda: AffineMap.get(1, 1, [None])),
            # Index 1 appears twice.
            (RuntimeError, lambda: AffineMap.get_permutation([1, 0, 1])),
            # map3 has only two results, so 42 is out of range.
            (ValueError, lambda: map3.get_submap([42])),
            (ValueError, lambda: map3.get_minor_submap(42)),
            (ValueError, lambda: map3.get_major_submap(42)),
        ]
        for expected, request in invalid_requests:
            try:
                request()
            except expected as e:
                print(e)
            else:
                raise AssertionError(
                    f"expected {expected.__name__} from an invalid AffineMap request"
                )
@run
def testAffineMapDerive():
    """Print submaps derived from the 5-D identity map."""
    with Context():
        identity5 = AffineMap.get_identity(5)
        # The middle results selected by position, then the leading pair
        # and the trailing pair of results.
        print(identity5.get_submap([1, 2, 3]))
        print(identity5.get_major_submap(2))
        print(identity5.get_minor_submap(2))
@run
def testAffineMapProperties():
    """Print both permutation predicates for three maps over three dims."""
    with Context():
        d0, d1, d2 = (AffineDimExpr.get(i) for i in range(3))
        candidates = (
            AffineMap.get(3, 0, [d2, d0]),  # omits d1
            AffineMap.get(3, 0, [d2, d0, d1]),  # uses every dim once
            AffineMap.get(3, 1, [d2, d0, d1]),  # same results, one symbol
        )
        for candidate in candidates:
            print(candidate.is_permutation)
            print(candidate.is_projected_permutation)
@run
def testAffineMapExprs():
    """Inspect a map's dim/symbol/input counts and its result expressions."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(pos) for pos in range(3)]
        amap = AffineMap.get(3, 1, [d2, d0, d1])

        for count in (amap.n_dims, amap.n_inputs, amap.n_symbols):
            print(count)
        # The input count is the sum of the dimension and symbol counts.
        assert amap.n_inputs == amap.n_dims + amap.n_symbols

        results = amap.results
        print(len(results))
        for expr in results:
            print(expr)
        # Negative-step slice: the three results from last to first.
        for expr in results[-1:-4:-1]:
            print(expr)
        assert list(results) == [d2, d0, d1]
@run
def testCompressUnusedSymbols():
    """Print three maps before and after compressing away unused symbols."""
    with Context() as ctx:
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        s0, s1, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
        # Only s2 appears in any result (in the second map).
        maps = [
            AffineMap.get(3, 3, [d2, d0, d1]),
            AffineMap.get(3, 3, [d2, d0 + s2, d1]),
            AffineMap.get(3, 3, [d1, d2, d0]),
        ]
        compressed = AffineMap.compress_unused_symbols(maps, ctx)
        print(maps)
        print(compressed)
@run
def testReplace():
    """Replace each symbol of a map with the constant 42, one at a time."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        s0, s1, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
        source = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
        forty_two = AffineConstantExpr.get(42)

        # The trailing two arguments are presumably the result's dim and
        # symbol counts. Replacing s2 (the last symbol) asks for only 2
        # symbols.
        replaced = [
            source.replace(symbol, forty_two, 3, n_symbols)
            for symbol, n_symbols in ((s0, 3), (s1, 3), (s2, 2))
        ]
        for result_map in replaced:
            print(result_map)
@run
def testHash():
    """Check that AffineMap hashing agrees with equality and works for dict keys.

    Two structurally equal maps built separately must hash the same and be
    interchangeable as dictionary keys. Two different maps must occupy
    distinct entries.
    """
    with Context():
        d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        m1 = AffineMap.get(2, 0, [d0, d1])
        m2 = AffineMap.get(2, 0, [d1, d0])

        # A freshly built but equal map must hash identically.
        m1_again = AffineMap.get(2, 0, [d0, d1])
        assert m1 == m1_again
        assert hash(m1) == hash(m1_again)

        dictionary = {m1: 1, m2: 2}
        # Distinct maps must not collapse into a single entry...
        assert len(dictionary) == 2
        assert m1 in dictionary
        assert m2 in dictionary
        # ...and a lookup through an equal object must find the right value.
        assert dictionary[m1_again] == 1
        assert dictionary[m2] == 2