import triton
import triton.language as tl
import torch
import torch_npu
import pytest
import test_common
import triton.language.extra.ascend.libdevice as libdevice
import numpy as np
from scipy.special import gamma
@triton.jit
def triton_gamma(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    # Element-wise gamma kernel: reads from in_ptr0, applies libdevice.gamma,
    # and writes the result to the same offsets in out_ptr0.
    #
    # in_ptr0    -- base pointer of the input buffer
    # out_ptr0   -- base pointer of the output buffer
    # xnumel     -- number of valid elements; offsets at or beyond it are masked
    # XBLOCK     -- elements assigned to one program (compile-time constant)
    # XBLOCK_SUB -- elements processed per loop iteration (compile-time constant)

    # Each program along axis 0 owns a contiguous span of XBLOCK elements.
    block_start = tl.program_id(0) * XBLOCK
    # Walk that span in chunks of XBLOCK_SUB elements.
    for sub_start in range(0, XBLOCK, XBLOCK_SUB):
        offsets = block_start + sub_start + tl.arange(0, XBLOCK_SUB)[:]
        # Guard the tail so loads and stores never touch offsets >= xnumel.
        in_bounds = offsets < xnumel
        values = tl.load(in_ptr0 + offsets, mask=in_bounds)
        result = libdevice.gamma(values)
        tl.store(out_ptr0 + offsets, result, mask=in_bounds)
@pytest.mark.parametrize(
    'param_list',
    [
        ['float32', (2, 4096, 8), 2, 32768, 1024],
    ],
)
def test_gamma_case(param_list):
    """Compare the Triton gamma kernel against ``scipy.special.gamma``.

    ``param_list`` unpacks to ``(dtype, shape, ncore, xblock, xblock_sub)``:
    the torch dtype name, the tensor shape, the number of programs launched
    on grid axis 0, the elements handled per program, and the elements
    handled per inner-loop iteration of the kernel.
    """
    dtype, shape, ncore, xblock, xblock_sub = param_list

    # The input is passed through abs(), so the kernel only sees
    # non-negative values.
    x = torch.abs(test_common.generate_tensor(shape, dtype))
    # Snapshot the input as a NumPy array before moving it to the device.
    x_np = x.cpu().numpy()
    x = x.npu()

    # Reference result computed by SciPy, then moved to x's device and dtype.
    y_ref = torch.from_numpy(gamma(x_np)).to(x.device).to(x.dtype).npu()
    # Resolve the dtype by attribute lookup rather than eval() on a built
    # string: same result for valid names, no arbitrary code evaluation.
    y_cal = torch.zeros(shape, dtype=getattr(torch, dtype)).npu()

    triton_gamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub)
    test_common.validate_cmp(dtype, y_cal, y_ref)
@pytest.mark.parametrize(
    'param_list',
    [
        ['float32', (2, 4096, 8), 2, 32768, 1024],
    ],
)
def test_all_blocks_parallel(param_list, monkeypatch):
    """Run the gamma kernel check with ``TRITON_ALL_BLOCKS_PARALLEL=1`` set.

    ``param_list`` unpacks to ``(dtype, shape, ncore, xblock, xblock_sub)``:
    the torch dtype name, the tensor shape, the number of programs launched
    on grid axis 0, the elements handled per program, and the elements
    handled per inner-loop iteration of the kernel.

    The environment variable is set through ``monkeypatch``, which restores
    the previous environment at teardown even when the test fails, so no
    manual ``delenv`` is needed at the end.
    """
    monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1")
    dtype, shape, ncore, xblock, xblock_sub = param_list

    # The input is passed through abs(), so the kernel only sees
    # non-negative values.
    x = torch.abs(test_common.generate_tensor(shape, dtype))
    # Snapshot the input as a NumPy array before moving it to the device.
    x_np = x.cpu().numpy()
    x = x.npu()

    # Reference result computed by SciPy, then moved to x's device and dtype.
    y_ref = torch.from_numpy(gamma(x_np)).to(x.device).to(x.dtype).npu()
    # Resolve the dtype by attribute lookup rather than eval() on a built
    # string: same result for valid names, no arbitrary code evaluation.
    y_cal = torch.zeros(shape, dtype=getattr(torch, dtype)).npu()

    triton_gamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub)
    test_common.validate_cmp(dtype, y_cal, y_ref)