import pypto
from pypto import pil, ir, logging
def _compile_ops(func, *args):
    """Compile ``func`` through ``pil.compile``, forwarding ``args`` unchanged.

    Returns whatever ``pil.compile`` returns; callers in this file read its
    ``body`` attribute.
    """
    compiled = pil.compile(func, *args)
    return compiled
def _tensor_ops_of(ops):
    """Return the items of ``ops`` that are ``ir.TensorOpStmt`` instances, in order."""
    wanted = ir.TensorOpStmt
    return [stmt for stmt in ops if isinstance(stmt, wanted)]
def test_tensor_binary_ops():
    """``/`` and ``**`` between two tensors must emit DIV and POW tensor ops."""

    def f(x, y):
        pypto.set_vec_tile_shapes(16, 16)
        z = x / y
        z = x ** y

    lhs = pypto.Tensor(shape=(4, 4), dtype=pypto.DT_FP32)
    rhs = pypto.Tensor(shape=(4, 4), dtype=pypto.DT_FP32)
    compiled = _compile_ops(f, lhs, rhs)
    emitted = {stmt.opcode for stmt in _tensor_ops_of(compiled.body)}
    assert 'DIV' in emitted
    assert 'POW' in emitted
def test_scalar_binary_bitwise_ops():
    """Bitwise operators applied to plain ints must not produce tensor ops."""

    # NOTE(review): `|` and `^` each appear twice in f; looks redundant — confirm.
    def f(x, y):
        z = x | y
        z = x ^ y
        z = x | y
        z = x ^ y
        z = x & y
        z = x << y
        z = x >> y

    compiled = _compile_ops(f, 1, 2)
    tensor_ops = _tensor_ops_of(compiled.body)
    assert not tensor_ops
def test_tensor_matmul_op():
    """``@`` between two tensors must emit an A_MUL_B tensor op."""

    def f(x, y):
        pypto.set_cube_tile_shapes([16, 16], [16, 16], [16, 16])
        z = x @ y

    lhs = pypto.Tensor(shape=(32, 32), dtype=pypto.DT_FP32)
    rhs = pypto.Tensor(shape=(32, 32), dtype=pypto.DT_FP32)
    compiled = _compile_ops(f, lhs, rhs)
    emitted = {stmt.opcode for stmt in _tensor_ops_of(compiled.body)}
    assert 'A_MUL_B' in emitted
def test_tensor_unary_ops():
    """Unary ``-`` and ``+`` on an INT32 tensor: a MULS tensor op must be emitted."""

    def f(x):
        pypto.set_vec_tile_shapes(16, 16)
        z = -x
        z = +x

    operand = pypto.Tensor(shape=(4, 4), dtype=pypto.DT_INT32)
    compiled = _compile_ops(f, operand)
    emitted = {stmt.opcode for stmt in _tensor_ops_of(compiled.body)}
    assert 'MULS' in emitted
def test_pypto2ir():
    """Smoke test: compile a function mixing pypto calls with a builtin call.

    The compiled function creates a local tensor, chains ``view`` / ``add`` /
    ``assemble`` calls through it, and ends with a bare ``min(...)`` call.
    Nothing is asserted; the test passes when ``pil.compile`` does not raise.
    """

    def f(x, y):
        pypto.set_vec_tile_shapes(16, 16)
        z = pypto.Tensor((x.shape[0], 64), dtype=pypto.DT_FP32, name="z")
        i = 0
        tx = pypto.view(x, [16, 16], [i * 32, 0])
        tx2 = pypto.add(tx, tx)
        pypto.assemble(tx2, [i * 32, 0], z)
        tz = pypto.view(z, [16, 16], [i * 32, 0])
        tz2 = pypto.add(tz, tz)
        pypto.assemble(tz2, [i * 32, 0], y)
        min(1, 2, 3, 4, 5)

    x = pypto.Tensor(shape=(-1, 64), dtype=pypto.DT_FP32, name="x")
    y = pypto.Tensor(shape=(-1, 64), dtype=pypto.DT_FP32, name="y")
    # The result is not inspected, so it is not bound to a name (matches the
    # other compile-only tests in this file).
    pil.compile(f, x, y)
def test_ir_range():
    """Compiling a plain ``range`` loop must record the same events as running it.

    ``foo`` is first called directly to capture the reference sequence in
    ``results``; the list is then reset and ``pil.compile(foo, 10)`` is
    expected to append an identical sequence.
    """
    results = []

    # NOTE(review): nesting of the two `total += ...` lines that follow an
    # `append` was reconstructed from flattened source — verify.
    def foo(n):
        total = 0
        for i in range(n):
            if i < 3:
                if i == 0:
                    results.append(("nested_if", i))
                else:
                    results.append(("nested_else", i))
                total += i
            if i == 5:
                for j in range(4):
                    if j == 1:
                        continue
                    if j == 2:
                        break
                    results.append(("inner", j))
                    total += j
            if i == 7:
                break
            total += i
        results.append(("final", i, total))

    foo(10)
    # Keep the reference run and rebind the closed-over list to a fresh one.
    expected = results
    results = []
    pil.compile(foo, 10)
    assert results == expected
def test_ir_loop():
    """Smoke test: a ``pypto.loop`` accumulator followed by one more update compiles."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
        ans = ans + 10

    pil.compile(foo, 10)
def _for_ops_of(ops):
    """Return the items of ``ops`` that are ``ir.ForStmt`` instances, in order."""
    wanted = ir.ForStmt
    return [stmt for stmt in ops if isinstance(stmt, wanted)]
def test_ir_loop_basic():
    """One accumulating loop: a single ForStmt, two carried values, constant bounds.

    The two carried values are ``ans`` and the loop variable ``i``.
    """

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    loop_stmt = loops[0]
    assert len(loop_stmt.iter_args) == 2
    assert len(loop_stmt.return_vars) == 2
    assert isinstance(loop_stmt.start, ir.ConstInt) and loop_stmt.start.value == 0
    assert isinstance(loop_stmt.stop, ir.ConstInt) and loop_stmt.stop.value == 10
    assert isinstance(loop_stmt.step, ir.ConstInt) and loop_stmt.step.value == 1
def test_ir_loop_no_carry():
    """An empty loop body still carries one value (the loop variable, here ``_``)."""

    def foo(n):
        for _ in pypto.loop(n):
            pass

    compiled = pil.compile(foo, 5)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    loop_stmt = loops[0]
    assert len(loop_stmt.iter_args) == 1
    assert len(loop_stmt.return_vars) == 1
def test_ir_loop_two_carries():
    """Two accumulators plus the loop variable give three carried values."""

    def foo(n):
        ans = 0
        count = 1
        for i in pypto.loop(n):
            ans += i
            count += 1

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    loop_stmt = loops[0]
    assert len(loop_stmt.iter_args) == 3
    assert len(loop_stmt.return_vars) == 3
def test_ir_loop_range_two_args():
    """``pypto.loop(start, stop)``: bounds are constants 2 and 10, step is 1."""

    def foo():
        ans = 0
        for i in pypto.loop(2, 10):
            ans += i

    compiled = pil.compile(foo)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    loop_stmt = loops[0]
    assert isinstance(loop_stmt.start, ir.ConstInt) and loop_stmt.start.value == 2
    assert isinstance(loop_stmt.stop, ir.ConstInt) and loop_stmt.stop.value == 10
    assert isinstance(loop_stmt.step, ir.ConstInt) and loop_stmt.step.value == 1
    assert len(loop_stmt.iter_args) == 2
def test_ir_loop_range_three_args():
    """``pypto.loop(start, stop, step)``: bounds are constants 1, 20 and 3."""

    def foo():
        ans = 0
        for i in pypto.loop(1, 20, 3):
            ans += i

    compiled = pil.compile(foo)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    loop_stmt = loops[0]
    assert isinstance(loop_stmt.start, ir.ConstInt) and loop_stmt.start.value == 1
    assert isinstance(loop_stmt.stop, ir.ConstInt) and loop_stmt.stop.value == 20
    assert isinstance(loop_stmt.step, ir.ConstInt) and loop_stmt.step.value == 3
    assert len(loop_stmt.iter_args) == 2
def test_ir_loop_sequential():
    """Two back-to-back loops: each ForStmt carries ``ans`` and its own loop variable."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
        for j in pypto.loop(n):
            ans += j

    compiled = pil.compile(foo, 5)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 2
    for loop_stmt in loops:
        assert len(loop_stmt.iter_args) == 2
        assert len(loop_stmt.return_vars) == 2
def test_ir_loop_nested():
    """Smoke test: a loop nested in another loop, with an if/else in the inner body."""

    # NOTE(review): nesting level of the trailing `ans += n` was reconstructed
    # from flattened source (placed in the outer loop body) — verify.
    def foo(n, m):
        ans = 0
        for i in pypto.loop(n):
            for j in pypto.loop(m):
                if j % 2 == 0:
                    ans += i + j
                else:
                    ans += i - j
            ans += n

    pil.compile(foo, 4, 5)
def test_ir_loop_nested1():
    """Smoke test: nested loops where the inner if-branch ends in ``break``."""

    # NOTE(review): nesting level of the trailing `ans += n` was reconstructed
    # from flattened source (placed in the outer loop body) — verify.
    def foo(n, m):
        ans = 0
        for i in pypto.loop(n):
            for j in pypto.loop(m):
                if j % 2 == 0:
                    ans += i + j
                    break
                else:
                    ans += i - j
            ans += n

    pil.compile(foo, 4, 5)
def test_ir_loop_carry_used_after():
    """Reading the accumulator after the loop keeps the loop at two carried values."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
        ans = ans + 10

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    assert len(loops[0].iter_args) == 2
def test_ir_loop_body_multiple_ops():
    """A temporary assigned in the loop and read afterwards is carried too.

    Three carried values are expected: ``i``, ``a`` and ``ans``.
    """

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            a = i * 2
            ans += a
        ans = ans + a

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    loop_stmt = loops[0]
    assert len(loop_stmt.iter_args) == 3
    assert len(loop_stmt.return_vars) == 3
def test_ir_deadcode():
    """Smoke test: a loop body containing a statement after ``break`` compiles."""

    # NOTE(review): `return x` is placed inside the loop (unreachable after
    # `break`) to match the test name; nesting was reconstructed — verify.
    def foo(n):
        for i in pypto.loop(n):
            x = i + 1
            break
            return x

    pil.compile(foo, 10)
def _collect_stmts(stmt, cls):
    """Return ``stmt`` and every nested statement that is an instance of ``cls``.

    Descends into ``ir.SeqStmts.stmts``, both bodies of ``ir.IfStmt`` and the
    body of ``ir.ForStmt``. Matches are listed parent first, then children in
    visiting order. Any other object (including a non-statement body) simply
    contributes nothing.
    """
    found = [stmt] if isinstance(stmt, cls) else []

    # Gather the direct children first, then recurse over them in order.
    children = []
    if isinstance(stmt, ir.SeqStmts):
        children.extend(stmt.stmts)
    if isinstance(stmt, ir.IfStmt):
        children.extend((stmt.then_body, stmt.else_body))
    if isinstance(stmt, ir.ForStmt):
        children.append(stmt.body)

    for child in children:
        found.extend(_collect_stmts(child, cls))
    return found
def test_ir_loop_break():
    """A ``break`` in the loop body must appear as exactly one BreakStmt."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
            break

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    found = _collect_stmts(loops[0].body, ir.BreakStmt)
    assert len(found) == 1
def test_ir_loop_continue():
    """A ``continue`` in the loop body must appear as exactly one ContinueStmt."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            continue
            ans += i

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    found = _collect_stmts(loops[0].body, ir.ContinueStmt)
    assert len(found) == 1
def test_ir_loop_return():
    """A ``return`` in the loop body must appear as exactly one ReturnStmt."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
            return

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    found = _collect_stmts(loops[0].body, ir.ReturnStmt)
    assert len(found) == 1
def test_ir_loop_break_in_if():
    """A guarded ``break`` yields an IfStmt whose then-branch holds one BreakStmt."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            if i:
                break
            ans += i

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    conditionals = _collect_stmts(loops[0].body, ir.IfStmt)
    assert len(conditionals) >= 1
    # Only the first IfStmt found is inspected.
    breaks_in_then = _collect_stmts(conditionals[0].then_body, ir.BreakStmt)
    assert len(breaks_in_then) == 1
def test_ir_loop_continue_in_if():
    """A guarded ``continue`` yields an IfStmt whose then-branch holds one ContinueStmt."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            if i:
                continue
            ans += i

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    conditionals = _collect_stmts(loops[0].body, ir.IfStmt)
    assert len(conditionals) >= 1
    # Only the first IfStmt found is inspected.
    continues_in_then = _collect_stmts(conditionals[0].then_body, ir.ContinueStmt)
    assert len(continues_in_then) == 1
def test_ir_loop_if_else_both_branches():
    """An if/else inside a loop must leave at least one IfStmt in the loop body."""

    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            if i:
                ans += i
            else:
                ans += 1

    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    conditionals = _collect_stmts(loops[0].body, ir.IfStmt)
    assert len(conditionals) >= 1
def test_tensor_add_dyn():
    """Smoke test: slicing and add/sub on tensors with a dynamic (-1) leading dim.

    Nothing is asserted; the test passes when ``pil.compile`` does not raise.
    """

    def foo(x, y):
        for i in pypto.loop(x.shape[0] // 32):
            pypto.set_vec_tile_shapes(32, 32)
            ta = x[i:i + 32, :]
            if i % 2 == 0:
                y[i:, :] = ta + 1
            else:
                y[i:, :] = ta - 1

    src = pypto.Tensor((-1, 32), pypto.DT_FP32, 'x')
    dst = pypto.Tensor((-1, 32), pypto.DT_FP32, 'y')
    pil.compile(foo, src, dst)
def test_fstring():
    """Compile a function whose body asserts on f-string results for x=10, y=20.

    Covers the ``=`` specifier, plain interpolation, arithmetic expressions,
    format specs and the ``!r`` / ``!s`` conversions.
    """

    def foo(x, y):
        assert f"{x + y=} , {x - y=}" == "x + y=30 , x - y=-10"
        assert f"x={x}, y={y}" == "x=10, y=20"
        assert f"sum={x + y}, diff={x - y}, prod={x * y}" == "sum=30, diff=-10, prod=200"
        assert f"x={x:05d}, y={y:05d}" == "x=00010, y=00020"
        assert f"mod={x % 3}, pow={x ** 2}" == "mod=1, pow=100"
        assert f"result: ({x} + {y}) = {x + y}" == "result: (10 + 20) = 30"
        assert f"{x!r}, {y!s}" == "10, 20"

    pil.compile(foo, 10, 20)