import asc
from asc.runtime import config
from asc.runtime.jit import MockTensor
def setup_function():
    """Select the ``Model`` backend before each test in this module.

    pytest calls a module-level ``setup_function`` ahead of every test
    function, so each test starts from the same platform configuration.
    """
    # NOTE(review): ``check=False`` presumably skips a platform availability
    # check -- confirm against ``config.set_platform``.
    model_backend = config.Backend.Model
    config.set_platform(model_backend, check=False)
def test_notify_next_block(mock_launcher_run):
    """Launch a JIT kernel that calls ``asc.notify_next_block``.

    The test passes when the ``mock_launcher_run`` object (presumably a
    fixture mocking the kernel launcher -- defined outside this file)
    reports exactly one call after the kernel is invoked.
    """

    @asc.jit
    def kernel_notify_next_block(gm_workspace_addr: asc.GlobalAddress) -> None:
        # GlobalTensor bound to the address passed in as the kernel argument.
        global_ws = asc.GlobalTensor()
        global_ws.set_global_buffer(gm_workspace_addr)
        # int32 LocalTensor at VECIN position used as the second operand.
        local_ws = asc.LocalTensor(
            dtype=asc.int32,
            pos=asc.TPosition.VECIN,
            addr=0,
            tile_size=32,
        )
        asc.notify_next_block(global_ws, local_ws)

    # NOTE(review): the ``[1]`` subscript is the launch configuration,
    # presumably the block/core count -- confirm against ``asc.jit``.
    launch = kernel_notify_next_block[1]
    launch(MockTensor(asc.int32))

    launches = mock_launcher_run.call_count
    assert launches == 1
def test_wait_pre_block(mock_launcher_run):
    """Launch a JIT kernel that calls ``asc.wait_pre_block``.

    The test passes when the ``mock_launcher_run`` object (presumably a
    fixture mocking the kernel launcher -- defined outside this file)
    reports exactly one call after the kernel is invoked.
    """

    @asc.jit
    def kernel_wait_pre_block(gm_workspace_addr: asc.GlobalAddress) -> None:
        # GlobalTensor bound to the address passed in as the kernel argument.
        global_ws = asc.GlobalTensor()
        global_ws.set_global_buffer(gm_workspace_addr)
        # int32 LocalTensor at VECIN position used as the second operand.
        local_ws = asc.LocalTensor(
            dtype=asc.int32,
            pos=asc.TPosition.VECIN,
            addr=0,
            tile_size=32,
        )
        asc.wait_pre_block(global_ws, local_ws)

    # NOTE(review): the ``[1]`` subscript is the launch configuration,
    # presumably the block/core count -- confirm against ``asc.jit``.
    launch = kernel_wait_pre_block[1]
    launch(MockTensor(asc.int32))

    launches = mock_launcher_run.call_count
    assert launches == 1