from unittest.mock import patch
import pytest
import asc
from asc.codegen.errors import CodegenError
from asc.runtime.jit import JITFunction, MockTensor
@pytest.fixture(autouse=True)
def mock_jit():
    """Stub out the JIT compiler and launcher for every test in this module.

    ``JITFunction._run_launcher`` and ``JITFunction._run_compiler`` are
    replaced by mocks returning ``None`` for the duration of each test, so
    the real compiler and launcher are never invoked.
    """
    launcher_stub = patch(
        "asc.runtime.jit.JITFunction._run_launcher", return_value=None
    )
    compiler_stub = patch(
        "asc.runtime.jit.JITFunction._run_compiler", return_value=None
    )
    # Entered in the same order as before: launcher first, then compiler.
    with launcher_stub, compiler_stub:
        yield
# JIT helper invoked from the kernel in test_func_visit_list: builds a list
# literal from its three arguments and returns it. The body is the codegen
# input under test, so it is kept exactly as written.
@asc.jit
def func_visit_list(a, b, c):
    nums = [a, b, c]
    return nums
def test_func_visit_list(filecheck):
    """Compile a kernel that calls the ``func_visit_list`` JIT helper.

    The kernel is run through the ``filecheck`` fixture, which is not defined
    in this section; presumably it verifies the generated output for the
    list literal built by the helper. Confirm against the fixture.
    """
    @asc.jit
    def func_visit_list_kernel():
        # The helper's return value is discarded; only the call is compiled.
        func_visit_list(1, 2, 3)

    filecheck(func_visit_list_kernel)()
def test_func_visit_bool_op(filecheck):
    """Compile a kernel whose ``if``/``elif`` conditions combine comparisons
    with ``and`` and ``or``, and run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_bool_op_kernel(value, min_threshold, max_threshold, cnt, step):
        # `and` chain: value inside [min_threshold, max_threshold].
        if value >= min_threshold and value <= max_threshold:
            cnt += step
        # `or` chain: value outside the range.
        elif value < min_threshold or value > max_threshold:
            cnt += step * 2
        # `ret` is never read; presumably the assignment only exists so that
        # an `==` comparison is emitted. Verify before removing.
        ret = cnt == step

    filecheck(func_visit_bool_op_kernel)(10, 5, 20, 0, 1)
def test_func_visit_compare(filecheck):
    """Compile a kernel that uses each of ``>``, ``<``, ``>=``, ``<=`` and
    ``==`` as an ``if`` condition, and run it through ``filecheck``."""
    @asc.jit
    def func_visit_compare_kernel(a, b, c):
        ans = 0
        if a > b:
            ans += 1
        # The left operand is itself a binary expression (a + b).
        if a + b < c:
            ans += 1
        if ans >= 10:
            ans += 1
        if ans <= 5:
            ans += 1
        if ans == c:
            ans += 1
        # Final value is bound to an unused local.
        ret = ans

    filecheck(func_visit_compare_kernel)(2, 1, 2)
# JIT helper invoked from the kernel in test_func_visit_constant: binds an int
# constant and a bool constant and returns them as a tuple. The body is the
# codegen input under test, so it is kept exactly as written.
@asc.jit
def func_visit_constant():
    a = 1
    b = True
    return a, b
def test_func_visit_constant(filecheck):
    """Compile a kernel that calls the ``func_visit_constant`` JIT helper and
    run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_constant_kernel():
        # The returned (int, bool) tuple is discarded.
        func_visit_constant()

    filecheck(func_visit_constant_kernel)()
def test_func_visit_if_exp(filecheck):
    """Compile a kernel that uses a conditional expression
    (``x if cond else y``) and run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_if_exp_kernel(x):
        ans = x if x > 0 else 0
        # Result is bound to an unused local.
        ret = ans

    filecheck(func_visit_if_exp_kernel)(10)
def test_func_visit_if(filecheck):
    """Compile a kernel containing a full ``if``/``elif``/``else`` chain and
    run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_if_kernel(x, y, z, ans, step):
        # Every branch performs the same update. Presumably only the branch
        # structure matters for this test, not the values computed.
        if x + y == z:
            ans += step
        elif x + y > z:
            ans += step
        else:
            ans += step
        # `ret` is never read.
        ret = ans == 1

    filecheck(func_visit_if_kernel)(1, 2, 5, 0, 1)
def test_func_visit_pass(filecheck):
    """Compile a JIT kernel whose body is only ``pass`` and run it through the
    ``filecheck`` fixture."""
    @asc.jit
    def func_visit_pass():
        pass

    filecheck(func_visit_pass)()
# JIT helper invoked from the kernel in test_func_visit_return: returns its
# single argument unchanged. The body is the codegen input under test.
@asc.jit
def func_visit_return(x):
    return x
def test_func_visit_return(filecheck):
    """Compile a kernel that calls the ``func_visit_return`` JIT helper with a
    constant argument and run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_return_kernel():
        # The returned value is discarded.
        func_visit_return(100)

    filecheck(func_visit_return_kernel)()
# JIT helper invoked from the kernel in test_func_visit_tuple: returns its
# three arguments as a tuple. The body is the codegen input under test.
@asc.jit
def func_visit_tuple(x, y, z):
    return x, y, z
def test_func_visit_tuple(filecheck):
    """Compile a kernel that calls the ``func_visit_tuple`` JIT helper and run
    it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_tuple_kernel():
        # The returned tuple is discarded.
        func_visit_tuple(1, 2, 3)

    filecheck(func_visit_tuple_kernel)()
def test_func_visit_unary_op(filecheck):
    """Compile a kernel that applies the unary bitwise-invert operator ``~`` to
    a constant and run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_unary_op_kernel(a):
        # In plain Python, ~1 == -2. The result is bound to an unused local.
        ret = a + ~1

    filecheck(func_visit_unary_op_kernel)(1)
def test_func_visit_while(filecheck):
    """Compile a kernel containing a ``while`` loop that sums ``0..n-1`` and
    run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_while_kernel(n):
        ans = 0
        i = 0
        while i < n:
            ans += i
            i += 1
        # Result is bound to an unused local.
        ret = ans

    filecheck(func_visit_while_kernel)(10)
def test_func_visit_augassign(filecheck):
    """Compile a kernel that uses the augmented assignments ``+=``, ``*=``,
    ``-=``, ``/=`` and ``%=`` and run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_augassign_kernel(a, b):
        result = a
        result += b
        result *= 2
        result -= 5
        # In plain Python `/=` yields a float even for int operands.
        result /= a
        result %= 3
        # Result is bound to an unused local.
        ret = result

    filecheck(func_visit_augassign_kernel)(2, 5)
def test_func_visit_ann_assign(filecheck):
    """Compile a kernel that uses annotated assignments (``name: type = value``)
    and run it through the ``filecheck`` fixture."""
    @asc.jit
    def func_visit_ann_assign_kernel(num1, num2, num3):
        # Annotated locals: an int and a float.
        count: int = 3
        sum_result: float = 1.0
        sum_result += num1
        sum_result += num2
        sum_result += num3
        # `result` is never read.
        result = sum_result / count

    filecheck(func_visit_ann_assign_kernel)(12, 18, 100)
def test_func_visit_arguments(filecheck):
    """Compile a kernel with annotated parameters (``asc.GlobalAddress``,
    ``int``, ``bool``) and run it through the ``filecheck`` fixture.

    The kernel's annotations are part of the code under test and must not be
    changed.
    """
    @asc.jit
    def func_visit_arguments(data: asc.GlobalAddress, threshold: int, flag: bool) -> None:
        pass

    # A MockTensor stands in for the GlobalAddress argument; its int32 dtype
    # is passed explicitly.
    data = MockTensor(asc.int32)
    filecheck(func_visit_arguments)(data, 32, True)
def test_joined_and_formatted_assert():
    """An always-false ``assert`` with an f-string message inside a JIT kernel
    must surface as a ``CodegenError`` whose text mentions ``AssertionError``."""
    @asc.jit
    def func_visit_joined_and_formatted_assert(num):
        assert 1 < 0, f"assert failed {num}"

    # `match` performs re.search on str(exception). The pattern contains no
    # regex metacharacters, so this is a plain substring check.
    with pytest.raises(CodegenError, match="AssertionError"):
        func_visit_joined_and_formatted_assert[1](100)
def test_error_test_str():
    """Passing a ``str`` argument to a JIT kernel must raise ``TypeError``
    naming the unsupported argument type."""
    @asc.jit
    def func_error_test_str(name):
        return name

    # `match` performs re.search on str(exception). The pattern contains no
    # regex metacharacters, so this is a plain substring check.
    with pytest.raises(TypeError, match="Argument type in JIT function is not supported: str"):
        func_error_test_str[1]("test")
def test_error_test_print():
    """Calling ``print`` inside a JIT kernel must raise a ``CodegenError``
    reporting that ``print`` is not defined."""
    @asc.jit
    def func_error_test_print(age):
        print(age)

    # `match` performs re.search on str(exception). The pattern contains no
    # regex metacharacters, so this is a plain substring check.
    with pytest.raises(CodegenError, match="NameError: print is not defined"):
        func_error_test_print[1](100)