"""
Test script for the gather-in-L1 + matmul kernel, driven by the GATHER_IN_L1_TESTS case table.
Supports both pytest and direct execution modes.
"""
import os
import pypto
import pytest
import torch
from testcase.gather_in_l1_test_case import GATHER_IN_L1_TESTS, GatherInL1Config
# JIT kernel: gathers from `src` via `indices` and `page_table` with
# pypto.experimental.gather_in_l1, multiplies the gathered operand with
# `mat2_tensor`, and writes the product into `out_tensor`.
#
# `config` supplies block_size / token_dim for the gather, the matmul output
# dtype, and two flags:
#   * is_gather_b_matrix - the gathered tensor is the right-hand (B) operand
#     instead of the left-hand (A) operand;
#   * is_gather_trans    - forwarded to gather_in_l1 as `is_trans` and to the
#     matmul as the transpose flag of whichever side holds the gathered tensor.
@pypto.frontend.jit(debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0})
def gather_matmul_pto_kernel(
    src: pypto.Tensor(),
    indices: pypto.Tensor(),
    page_table: pypto.Tensor(),
    mat2_tensor: pypto.Tensor(),
    out_tensor: pypto.Tensor(),
    config: GatherInL1Config,
):
    pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])

    # View every gather input over its own full shape, with valid_shape equal
    # to that same shape (the [0, 0] argument is presumably the view offset).
    src_view = pypto.view(src, src.shape, [0, 0], valid_shape=src.shape)
    indices_view = pypto.view(indices, indices.shape, [0, 0], valid_shape=indices.shape)
    page_table_view = pypto.view(page_table, page_table.shape, [0, 0], valid_shape=page_table.shape)

    gathered = pypto.experimental.gather_in_l1(
        src_view,
        indices_view,
        page_table_view,
        config.block_size,
        config.token_dim,
        is_b_matrix=config.is_gather_b_matrix,
        is_trans=config.is_gather_trans,
    )

    if config.is_gather_b_matrix:
        # Gathered tensor on the right: only the B-side transpose flag may be set.
        out_tensor[:] = pypto.matmul(
            mat2_tensor, gathered, config.out_dtype,
            a_trans=False, b_trans=config.is_gather_trans,
        )
    else:
        # Gathered tensor on the left: only the A-side transpose flag may be set.
        out_tensor[:] = pypto.matmul(
            gathered, mat2_tensor, config.out_dtype,
            a_trans=config.is_gather_trans, b_trans=False,
        )
def run_gather_in_l1_test(case: dict):
    """Run one gather-in-L1 + matmul case on the NPU and check it against torch.

    The NPU device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (default 0) and the torch RNG is seeded with 42, so the random
    inputs are reproducible.

    Args:
        case: Test-case dict. This function reads its ``in_dtype``,
            ``out_dtype``, ``id`` and ``name`` entries directly; the dict is
            also handed to ``GatherInL1Config.from_test_case``.

    Raises:
        AssertionError: If the kernel output is not within the tolerance
            returned by ``GatherInL1Config.get_tolerance`` of the reference.
    """
    npu_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(npu_id)
    torch.manual_seed(42)

    config = GatherInL1Config.from_test_case(case)
    in_dtype = GatherInL1Config.get_torch_dtype(case["in_dtype"])
    out_dtype = GatherInL1Config.get_torch_dtype(case["out_dtype"])
    gather_is_b = config.is_gather_b_matrix
    gather_is_trans = config.is_gather_trans

    # Token buffer: small integers for int8 inputs, normal noise otherwise.
    buffer_shape = (config.num_buffer_tokens, config.token_dim)
    if case["in_dtype"] == "DT_INT8":
        all_buffer = torch.randint(-5, 5, buffer_shape, dtype=in_dtype)
    else:
        all_buffer = torch.randn(*buffer_shape, dtype=in_dtype)

    # Square all-ones second operand, sized to match the gathered operand's
    # contraction dimension for the selected side / transpose combination.
    side = config.token_dim if gather_is_b == gather_is_trans else config.topk
    unit_tensor = torch.ones((side, side), dtype=in_dtype)

    # Page table: a random selection of distinct physical block ids.
    physical_block_count = config.num_buffer_tokens // config.block_size
    shuffled_blocks = torch.randperm(physical_block_count)
    page_table = shuffled_blocks[:config.num_logical_blocks].to(torch.int32).view(1, -1)

    # Logical token indices drawn from the range the page table covers.
    logical_token_count = config.block_size * page_table.shape[1]
    indices = torch.randint(0, logical_token_count, (config.topk,)).to(torch.int32).view(1, -1)

    # Reference gather: logical index -> (page-table block, in-block offset) -> buffer row.
    flat_indices = indices.view(-1)
    logical_block = flat_indices // config.block_size
    in_block_offset = flat_indices % config.block_size
    buffer_rows = page_table.view(-1)[logical_block] * config.block_size + in_block_offset
    gathered_tokens = all_buffer[buffer_rows]

    # Reference matmul with the gathered operand on the configured side.
    gathered_operand = gathered_tokens.T if gather_is_trans else gathered_tokens
    if gather_is_b:
        product = torch.matmul(unit_tensor, gathered_operand)
    else:
        product = torch.matmul(gathered_operand, unit_tensor)
    golden = product.to(out_dtype)

    out_tensor = torch.zeros(golden.shape, dtype=out_dtype, device=f"npu:{npu_id}")
    # NOTE(review): out_tensor is already on the NPU, so .npu() is expected to
    # return the same tensor and the kernel's writes are what gets compared
    # below -- confirm for the torch_npu version in use.
    gather_matmul_pto_kernel(
        all_buffer.npu(),
        indices.npu(),
        page_table.npu(),
        unit_tensor.npu(),
        out_tensor.npu(),
        config,
    )

    atol, rtol = GatherInL1Config.get_tolerance(case["out_dtype"])
    within_tolerance = torch.allclose(out_tensor.cpu(), golden.cpu(), atol=atol, rtol=rtol)
    assert within_tolerance, f"Test case {case['id']} ({case['name']}) failed"
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(test_case, marks=pytest.mark.soc(*test_case["products"]))
        for test_case in GATHER_IN_L1_TESTS
    ],
)
def test_gather_in_l1_basic(case: dict):
    """Pytest entry point: one run per GATHER_IN_L1_TESTS entry.

    Each parameter carries a ``soc`` mark built from the entry's ``products``.
    """
    run_gather_in_l1_test(case)
def create_gather_demo_inputs(device: str, config: GatherInL1Config):
    """Create random float16 / int32 inputs for the gather demo on ``device``.

    Args:
        device: Torch device string on which all returned tensors are placed.
        config: Supplies ``num_buffer_tokens``, ``token_dim``, ``block_size``,
            ``num_logical_blocks``, ``topk`` and the two gather flags.

    Returns:
        Tuple ``(all_buffer, indices, page_table, unit_tensor)``:
        * ``all_buffer``: float16, shape ``(num_buffer_tokens, token_dim)``.
        * ``indices``: int32, shape ``(1, topk)``, values in
          ``[0, block_size * page_table.shape[1])``.
        * ``page_table``: int32, shape ``(1, n)``, the first
          ``num_logical_blocks`` entries of a random permutation of
          ``num_buffer_tokens // block_size`` block ids.
        * ``unit_tensor``: float16 square all-ones matrix whose side is
          ``token_dim`` when the two gather flags are equal, else ``topk``.
    """
    fp16_on_device = {"dtype": torch.float16, "device": device}

    all_buffer = torch.randn(config.num_buffer_tokens, config.token_dim, **fp16_on_device)

    flags_match = config.is_gather_b_matrix == config.is_gather_trans
    side = config.token_dim if flags_match else config.topk
    unit_tensor = torch.ones((side, side), **fp16_on_device)

    # Both index tensors are generated without a device argument and moved afterwards.
    physical_block_count = config.num_buffer_tokens // config.block_size
    chosen_blocks = torch.randperm(physical_block_count)[:config.num_logical_blocks]
    page_table = chosen_blocks.to(torch.int32).view(1, -1).to(device)

    logical_token_count = config.block_size * page_table.shape[1]
    picks = torch.randint(0, logical_token_count, (config.topk,))
    indices = picks.to(torch.int32).view(1, -1).to(device)

    return all_buffer, indices, page_table, unit_tensor
def run_gather_in_l1_demo(run_mode):
    """Build and launch a small, fixed-size gather-in-L1 + matmul demo kernel.

    The kernel is JIT-compiled with both debug modes set to 1 and launched on
    random inputs from ``create_gather_demo_inputs``. The output tensor is
    allocated with ``torch.empty``; it is neither checked nor returned.

    Args:
        run_mode: ``"npu"`` to run with ``pypto.RunMode.NPU`` on tensors placed
            on ``npu:0``, or ``"sim"`` to run with ``pypto.RunMode.SIM`` on CPU
            tensors.

    Raises:
        ValueError: If ``run_mode`` is neither ``"npu"`` nor ``"sim"``.
    """
    # Reject an unknown mode before doing any other work.
    if run_mode == "npu":
        mode = pypto.RunMode.NPU
    elif run_mode == "sim":
        mode = pypto.RunMode.SIM
    else:
        raise ValueError(f"Invalid run_mode: {run_mode}. Must be 'npu' or 'sim'")

    config = GatherInL1Config(
        topk=8,
        num_logical_blocks=3,
        num_buffer_tokens=32,
        token_dim=30,
        block_size=10,
        is_gather_b_matrix=False,
        is_gather_trans=False,
        out_dtype=pypto.DT_FP16
    )

    # `config` and `mode` are captured from the enclosing scope by the kernel.
    @pypto.frontend.jit(
        debug_options={"runtime_debug_mode": 1, "compile_debug_mode": 1},
        runtime_options={"run_mode": mode}
    )
    def gather_demo_kernel(
        src: pypto.Tensor([], pypto.DT_FP16),
        indices: pypto.Tensor([], pypto.DT_INT32),
        page_table: pypto.Tensor([], pypto.DT_INT32),
        mat2: pypto.Tensor([], pypto.DT_FP16),
        out: pypto.Tensor([], pypto.DT_FP16),
    ):
        pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
        dyn_offsets = pypto.view(indices, indices.shape, [0, 0], valid_shape=indices.shape)
        dyn_src = pypto.view(src, src.shape, [0, 0], valid_shape=src.shape)
        dyn_page_table = pypto.view(page_table, page_table.shape, [0, 0], valid_shape=page_table.shape)
        input_tensor = pypto.experimental.gather_in_l1(
            dyn_src, dyn_offsets, dyn_page_table, config.block_size, config.token_dim,
            is_b_matrix=config.is_gather_b_matrix, is_trans=config.is_gather_trans
        )
        # The gathered tensor is the A operand unless is_gather_b_matrix is set;
        # the transpose flag is only ever applied to the gathered side.
        if not config.is_gather_b_matrix:
            out[:] = pypto.matmul(input_tensor, mat2, pypto.DT_FP16,
                                  a_trans=config.is_gather_trans, b_trans=False)
        else:
            out[:] = pypto.matmul(mat2, input_tensor, pypto.DT_FP16,
                                  a_trans=False, b_trans=config.is_gather_trans)

    device = "npu:0" if run_mode == "npu" else "cpu"
    all_buffer, indices, page_table, unit_tensor = create_gather_demo_inputs(device, config)

    # The second operand from create_gather_demo_inputs is square, so the
    # product has the shape of the gathered operand as it enters the matmul:
    # (topk, token_dim), or (token_dim, topk) when is_gather_trans is set.
    # Which side the gathered operand sits on does not change the shape; this
    # matches the reference shapes computed in run_gather_in_l1_test.
    if config.is_gather_trans:
        out_shape = (config.token_dim, config.topk)
    else:
        out_shape = (config.topk, config.token_dim)

    out_tensor = torch.empty(out_shape, dtype=torch.float16, device=device)
    gather_demo_kernel(all_buffer, indices, page_table, unit_tensor, out_tensor)
# Direct execution (not under pytest) launches the demo kernel in "npu" mode.
if __name__ == "__main__":
    run_gather_in_l1_demo("npu")