import pytest
import triton
import triton.language as tl
import test_common
import torch
import torch_npu
types_all = [
(torch.float32, 'float32'),
]
shapes_common = [
(128, 256),
(127, 256),
(127, 16),
(129, 256),
(77, 1024),
(69, 512),
(512, 512)
]
block_size = [
128,
256,
1024
]
def ceil_div(a, b):
return (a + b - 1) // b
def profiler_wrapper(fn, *args):
result_path = "./result_profiling_tl_where"
skip_first = 10
wait = 0
warmup = 3
active = 30
repeat = 1
stream = torch.npu.current_stream()
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
l2_cache=False,
data_simplification=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat,
skip_first=skip_first),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path),
record_shapes=True,
profile_memory=False,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
stream.synchronize()
for _ in range(skip_first + (wait + warmup + active) * repeat):
fn(*args)
prof.step()
stream.synchronize()
@triton.jit
def tl_where_kernel(
in_ptr,
output_ptr,
N: tl.constexpr,
M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(axis=0)
offset = tl.multiple_of(pid * BLOCK_SIZE_N, N)
x1 = (offset + tl.arange(0, BLOCK_SIZE_N)) // N
mask1 = tl.where(x1 < M, 1, 0).to(tl.int1)
data = tl.load(in_ptr + x1 * N, mask=mask1, other=0)
x2 = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
tl.store(output_ptr + x2, data)
def torch_tl_where(in_tensor):
M = in_tensor.shape[0] // 2
N = in_tensor.shape[1]
output = torch.zeros_like(in_tensor)
first_elements = in_tensor[:M, 0:1]
output[:M] = first_elements.expand(-1, N)
return output
@pytest.mark.parametrize('dtype, sigtype', types_all)
@pytest.mark.parametrize('M, N', shapes_common)
@pytest.mark.parametrize('BLOCK_SIZE_N', block_size)
def test_tl_where(M, N, BLOCK_SIZE_N, dtype, sigtype):
in_tensor = torch.randn(2 * M, N, dtype=dtype).npu()
triton_output = torch.zeros_like(in_tensor)
grid = (ceil_div(2 * M * N, BLOCK_SIZE_N),)
tl_where_kernel[grid](
in_tensor,
triton_output,
N=N,
M=M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
optimize_dynamic_offset=False
)
torch_output = torch_tl_where(in_tensor.clone())
assert torch.allclose(triton_output, torch_output, rtol=1e-5, atol=1e-8)
def triton_tl_where(in_tensor, BLOCK_SIZE):
M = in_tensor.shape[0] // 2
N = in_tensor.shape[1]
triton_output = torch.zeros_like(in_tensor)
grid = (ceil_div(2 * M * N, BLOCK_SIZE),)
tl_where_kernel[grid](
in_tensor,
triton_output,
N=N,
M=M,
BLOCK_SIZE_N=BLOCK_SIZE,
optimize_dynamic_offset=True
)
def profile_performance_test(M, N, dtype, BLOCK_SIZE):
print(f"\nDetailed performance analysis: M={M}, N={N}, dtype={dtype}, block_size={BLOCK_SIZE}")
in_tensor = torch.randn(2 * M, N, dtype=dtype).npu()
def wrapper_func(x):
triton_tl_where(x, BLOCK_SIZE=BLOCK_SIZE)
profiler_wrapper(wrapper_func, in_tensor)
if __name__ == "__main__":
profile_performance_test(512, 512, torch.float32, BLOCK_SIZE=1024)
print("\n" + "=" * 80)
print("Test completed!")
print(f"Detailed performance analysis results saved in: ./result_profiling_tl_where/")
print("=" * 80)