import asc
from asc.runtime import config
from asc.runtime.jit import MockTensor
def setup_function():
config.set_platform(config.Backend.Model, check=False)
def test_set_atomic_add_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_add_kernel() -> None:
asc.set_atomic_add(asc.half)
set_atomic_add_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_max_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_max_kernel() -> None:
asc.set_atomic_max(asc.half)
set_atomic_max_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_min_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_min_kernel() -> None:
asc.set_atomic_min(asc.half)
set_atomic_min_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_none_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_none_kernel() -> None:
asc.set_atomic_none()
set_atomic_none_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_add_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_add_kernel() -> None:
asc.set_atomic_add(asc.half)
set_atomic_add_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_max_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_max_kernel() -> None:
asc.set_atomic_max(asc.half)
set_atomic_max_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_min_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_min_kernel() -> None:
asc.set_atomic_min(asc.half)
set_atomic_min_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_none_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_none_kernel() -> None:
asc.set_atomic_none()
set_atomic_none_kernel[1]()
assert mock_launcher_run.call_count == 1
def test_set_atomic_type_kernel(mock_launcher_run):
@asc.jit(always_compile=True)
def set_atomic_type_kernel() -> None:
asc.set_atomic_type(asc.half)
set_atomic_type_kernel[1]()
assert mock_launcher_run.call_count == 1