import math
import pytest
import torch
import triton
import triton.language as tl
import test_common
from test_common import TestUtils
filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'int64', 'bool'}]
@triton.jit
def atomic_or(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr):
in_offset = tl.program_id(0) * BLOCK_SIZE
out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE
in_index = in_offset + tl.arange(0, BLOCK_SIZE)
out_index = out_offset + tl.arange(0, BLOCK_SIZE)
xmask = in_index < n_elements
tmp0 = tl.load(in_ptr0 + (in_index), xmask)
tl.atomic_or(out_ptr0 + (out_index), tmp0, xmask)
@triton.jit
def atomic_or_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
DIM0: tl.constexpr, DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr):
sub_idx = tl.program_id(1)
base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE
base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE
offsets_src = tl.arange(0, BLOCK_SIZE) + base_src
offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst
mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4
tmp = tl.load(x_ptr + offsets_src, mask)
tl.atomic_or(out_ptr + offsets_dst, tmp, mask)
@triton.jit
def atomic_or_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
x = tl.load(x_ptr)
y_offset = pid * BLOCK_SIZE
y_indices = y_offset + tl.arange(0, BLOCK_SIZE)
y_mask = y_indices < n_elements
y_value = tl.load(y_ptr + y_indices, y_mask)
tl.atomic_or(out_ptr + y_indices, y_value, mask=y_mask)
tl.atomic_or(out_ptr + y_indices, x, mask=y_mask)
test_cases = [
((1, 1, 1, 1), (1, 1, 1, 4), 4),
((1, 1, 1, 3), (1, 5, 1, 3), 5),
((3,), (2, 3, 3, 3, 3), 81),
((3,), (2, 3, 3, 3), 27),
((3,), (2, 3, 3), 9),
((3,), (2, 3), 3),
]
@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d)
@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
def test_atomic_or(x_dtype_str, shape):
x_dtype = eval('torch.' + x_dtype_str)
x_shape = list(shape[:])
x_shape[0] *= 2
x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
y = torch.full(shape, 0, dtype=x_dtype).npu()
x_temp = x.clone()
y_temp = y.clone()
if len(shape) == 2:
n_elements = shape[0] * shape[1] * 2
atomic_or[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0])
elif len(shape) == 1:
n_elements = shape[0]
BLOCK_SIZE = min(1024, shape[0])
grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
aligned_size = grid_size * BLOCK_SIZE
x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu()
x_concat[0:n_elements] = x[0:n_elements]
x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)]
atomic_or[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size)
expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)]
torch.testing.assert_close(y, expected)
@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
def test_atomic_or_3d(x_dtype_str, shape):
x_dtype = eval('torch.' + x_dtype_str)
x_shape = list(shape[:])
x_shape[0] *= 2
x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
y = torch.full(shape, 0, dtype=x_dtype).npu()
x_temp = x.clone()
y_temp = y.clone()
n_elements = shape[0] * shape[1] * shape[2]
atomic_or[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1)
expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)]
torch.testing.assert_close(y, expected)
@triton.jit
def atomic_or_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr):
offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
if (YB * ZB * MB * NB) > 1:
offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
if (ZB * MB * NB) > 1:
offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
if (MB * NB) > 1:
offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
if NB > 1:
offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
tmp0 = tl.load(in_ptr0 + offsets)
tl.atomic_or(out_ptr0 + offsets, tmp0)
@pytest.mark.shape_4d_5d
@pytest.mark.parametrize('shape', [
(2, 4, 8, 4),
(8, 4, 2, 4),
(2, 8, 2, 2),
(2, 4, 8, 4, 2),
(8, 4, 2, 4, 4),
(2, 8, 2, 2, 2),
])
@pytest.mark.parametrize('dtype', filtered_dtype)
def test_atomic_or_4d_5d(dtype, shape):
x0_value = 3
x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu()
x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu()
x1_ref = x1 | x0_value
triton_shape = [*shape]
while len(triton_shape) < 5:
triton_shape.append(1)
atomic_or_multi_d[(1, )](x0, x1, *triton_shape)
test_common.validate_cmp(dtype, x1, x1_ref)
@pytest.mark.parametrize('shape',
[
(1, 1, 1, 1, 2),
(10, 1, 15, 1, 7),
(1, 1, 1, 1, 257),
]
)
@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
def test_atomic_or_5d(x_dtype_str, shape):
shape = shape
x_dtype = eval('torch.' + x_dtype_str)
x_shape = list(shape[:])
x_shape[0] *= 2
x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
out = torch.full(shape, 0, dtype=x_dtype).npu()
x_temp = x.clone()
out_temp = out.clone()
triton_shape = [*shape]
while len(triton_shape) < 5:
triton_shape.append(1)
XB, YB, ZB, MB, NB = triton_shape
BLOCK_SIZE = 256
ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE
atomic_or_ndim[(2 * XB * YB * ZB * MB, ncore)](
x_ptr=x,
out_ptr=out,
NCORE=ncore,
BLOCK_SIZE=BLOCK_SIZE,
DIM0=XB, DIM1=YB, DIM2=ZB, DIM3=MB, DIM4=NB,
)
expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
torch.testing.assert_close(out, expected)
@pytest.mark.parametrize('shape',
[
(1, 1, 1, 1),
(1, 1, 2, 2),
(1, 3, 2, 7),
(1, 3, 2, 651),
]
)
@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
def test_atomic_or_4d(x_dtype_str, shape):
shape = shape
x_dtype = eval('torch.' + x_dtype_str)
x_shape = list(shape[:])
x_shape[0] *= 2
x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
out = torch.full(shape, 0, dtype=x_dtype).npu()
x_temp = x.clone()
out_temp = out.clone()
triton_shape = [*shape]
while len(triton_shape) < 4:
triton_shape.append(1)
XB, YB, ZB, MB = triton_shape
BLOCK_SIZE = 256
ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE
atomic_or_ndim[(2 * XB * YB * ZB, ncore)](
x_ptr=x,
out_ptr=out,
NCORE=ncore,
BLOCK_SIZE=BLOCK_SIZE,
DIM0=1, DIM1=XB, DIM2=YB, DIM3=ZB, DIM4=MB,
)
expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
torch.testing.assert_close(out, expected)
@pytest.mark.parametrize('shape',
[
(1, 1, 1),
(1, 1, 2),
(1, 31, 275),
]
)
@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
def test_atomic_or_3d_2(x_dtype_str, shape):
shape = shape
x_dtype = eval('torch.' + x_dtype_str)
x_shape = list(shape[:])
x_shape[0] *= 2
x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
out = torch.full(shape, 0, dtype=x_dtype).npu()
x_temp = x.clone()
out_temp = out.clone()
triton_shape = [*shape]
while len(triton_shape) < 3:
triton_shape.append(1)
XB, YB, ZB = triton_shape
BLOCK_SIZE = 256
ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE
atomic_or_ndim[(2 * XB * YB, ncore)](
x_ptr=x,
out_ptr=out,
NCORE=ncore,
BLOCK_SIZE=BLOCK_SIZE,
DIM0=1, DIM1=1, DIM2=XB, DIM3=YB, DIM4=ZB,
)
expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
torch.testing.assert_close(out, expected)
@pytest.mark.parametrize('shape',
[
(1, 2),
(1, 1),
(257, 1),
(257, 2),
]
)
@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
def test_atomic_or_2d(x_dtype_str, shape):
shape = shape
x_dtype = eval('torch.' + x_dtype_str)
x_shape = list(shape[:])
x_shape[0] *= 2
x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
out = torch.full(shape, 0, dtype=x_dtype).npu()
x_temp = x.clone()
out_temp = out.clone()
triton_shape = [*shape]
while len(triton_shape) < 2:
triton_shape.append(1)
XB, YB = triton_shape
BLOCK_SIZE = 256
ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE
atomic_or_ndim[(2 * XB, ncore)](
x_ptr=x,
out_ptr=out,
NCORE=ncore,
BLOCK_SIZE=BLOCK_SIZE,
DIM0=1, DIM1=1, DIM2=1, DIM3=XB, DIM4=YB,
)
expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
torch.testing.assert_close(out, expected)
@pytest.mark.parametrize('shape', [(1,), (9,), (256,), (257,), (65535,), (65536,)])
@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
def test_atomic_or_1d(x_dtype_str, shape):
shape = shape
x_dtype = eval('torch.' + x_dtype_str)
x_shape = list(shape[:])
x_shape[0] *= 2
x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
out = torch.full(shape, 0, dtype=x_dtype).npu()
x_temp = x.clone()
out_temp = out.clone()
triton_shape = [*shape]
while len(triton_shape) < 2:
triton_shape.append(1)
XB = triton_shape[0]
BLOCK_SIZE = 256
ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE
atomic_or_ndim[(2, ncore)](
x_ptr=x,
out_ptr=out,
NCORE=ncore,
BLOCK_SIZE=BLOCK_SIZE,
DIM0=1, DIM1=1, DIM2=1, DIM3=1, DIM4=XB,
)
expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
torch.testing.assert_close(out, expected)
@pytest.mark.parametrize('param_list',
[
['uint8', (32, 32), 2],
['uint16', (32, 32), 2],
['uint32', (32, 32), 2],
['uint64', (32, 32), 2],
]
)
def test_atomic_or_uint(param_list):
dtype, shape, ncore = param_list
block_size = shape[0] * shape[1] // ncore
split_size = shape[0] // ncore
val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu()
val = val_cpu.to("npu")
pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu()
pointer = pointer_cpu.to("npu")
pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu()
pointer_old = pointer_old_cpu.to("npu")
pointer_ref_cpu = pointer_cpu.clone()
for i in range(ncore - 1):
pointer_ref_cpu |= val_cpu[(i * split_size):((i + 1) * split_size)]
pointer_ref_last = pointer_ref_cpu.clone()
pointer_ref_cpu |= val_cpu[((ncore - 1) * split_size):(ncore * split_size)]
pointer_ref = pointer_ref_cpu.to("npu")
@triton.jit
def atomic_or_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
xoffset = tl.program_id(0) * BLOCK_SIZE
xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
yindex = tl.arange(0, BLOCK_SIZE)[:]
xmask = xindex < n_elements
x0 = xindex
x1 = yindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tl.atomic_or(out_ptr0 + (x1), tmp0, xmask)
tl.store(out_ptr1 + (x1), tmp1, xmask)
n_elements = shape[0] * shape[1]
atomic_or_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1])
test_common.validate_cmp(dtype, pointer, pointer_ref)