import triton
import triton.language as tl
import torch
import pytest
import math
import test_common
@triton.jit
def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
    """Histogram the first ``M`` values at ``x_ptr`` into ``N`` bins at ``z_ptr``.

    Binning is delegated to ``tl.histogram(values, N)``, whose bins have
    width 1 and start at 0. The ``N`` resulting counts are written to
    ``z_ptr[0:N]``.

    The kernel never reads ``tl.program_id``, so every program instance in
    the launch grid loads the whole input and stores the same full
    histogram. No masks are applied: ``x_ptr`` must address at least ``M``
    elements and ``z_ptr`` at least ``N``. (Upstream Triton also requires
    ``tl.arange`` ranges, i.e. ``M`` and ``N``, to be powers of two.)
    """
    # Pull all M input values in as one block.
    values = tl.load(x_ptr + tl.arange(0, M))
    counts = tl.histogram(values, N)
    # One output slot per bin; tl.store casts to z_ptr's element type.
    tl.store(z_ptr + tl.arange(0, N), counts)
@pytest.mark.parametrize("M", [2048])
@pytest.mark.parametrize("N", [2])
@pytest.mark.parametrize("ncore", [2])
@pytest.mark.parametrize("dtype", ["int32", "int64"])
def test_histogram(M, N, ncore, dtype):
    """Compare ``histogram_kernel`` with ``torch.histc`` on signed integer input.

    ``M`` random integers in ``[0, N)`` are generated on the NPU. They are
    histogrammed once by the Triton kernel and once by ``torch.histc``, and
    the two count vectors must agree.
    """
    torch.manual_seed(17)
    # getattr instead of eval: same dtype object, no code evaluation.
    torch_dtype = getattr(torch, dtype)
    x = torch.randint(low=0, high=N, size=(M,), dtype=torch_dtype).npu()
    # Reference counts. With min=0 and max=N-1, histc places integer value k
    # in bin k; the value N-1 equals max and lands in the closed last bin.
    # histc returns a float tensor, so cast it to the kernel's output dtype
    # so validate_cmp compares two tensors of the same dtype.
    y_cal = torch.histc(x.float(), bins=N, min=0, max=N - 1).to(torch_dtype)
    # Kernel output. All `ncore` programs compute the same full histogram,
    # because the kernel ignores its program id.
    y_ref = torch.empty(N, dtype=torch_dtype, device="npu")
    histogram_kernel[(ncore, )](x, y_ref, M=M, N=N)
    test_common.validate_cmp(dtype, y_cal, y_ref)
@pytest.mark.parametrize("M", [2048])
@pytest.mark.parametrize("N", [2])
@pytest.mark.parametrize("ncore", [2])
@pytest.mark.parametrize("dtype", ["uint32", "uint64"])
def test_histogram_uint(M, N, ncore, dtype):
    """Compare ``histogram_kernel`` with ``torch.histc`` on unsigned integer input.

    ``M`` random integers in ``[0, N)`` are histogrammed by the Triton kernel
    and by ``torch.histc``. The reference is cast to ``dtype`` and the two
    count vectors must agree.
    """
    torch.manual_seed(17)
    # Resolve the dtype once; getattr avoids eval on the parameter string.
    torch_dtype = getattr(torch, dtype)
    # Sample on the CPU, then copy to the NPU. Presumably this is because
    # randint for unsigned dtypes is not available on the NPU; TODO confirm.
    x_cpu = torch.randint(low=0, high=N, size=(M,), dtype=torch_dtype, device="cpu")
    x = x_cpu.to("npu")
    # Reference counts. histc places integer k in bin k for min=0, max=N-1,
    # and returns floats, so cast back to the kernel's output dtype.
    y_cal = torch.histc(x.float(), bins=N, min=0, max=N - 1)
    y_cal = y_cal.to(torch_dtype)
    # Kernel output; every one of the `ncore` programs writes the same counts.
    y_ref = torch.empty(N, dtype=torch_dtype, device="npu")
    histogram_kernel[(ncore, )](x, y_ref, M=M, N=N)
    test_common.validate_cmp(dtype, y_cal, y_ref)