import logging
import random
import pytest
import triton
import triton.language as tl
import torch
import test_common
import numpy as np
import triton.language.extra.cann.extension as extension
def gen_1d_cat_shapes(min_val=1, max_val=4096):
    """Draw a random pair of 1-D shapes to concatenate.

    Both lengths are sampled independently with ``random.randint`` from the
    inclusive range ``[min_val, max_val]``.

    Returns:
        A ``(shape1, shape2, dim)`` triple; each shape is a 1-tuple and
        ``dim`` is always 0.
    """
    # Two draws, in order: first tensor's length, then the second's.
    lengths = [random.randint(min_val, max_val) for _ in range(2)]
    first_shape, second_shape = ((length,) for length in lengths)
    return first_shape, second_shape, 0
def gen_2d_cat_shapes(dim=0, min_val=1, max_val=4096):
    """Draw a random pair of 2-D shapes that can be concatenated along ``dim``.

    The extent of the non-concatenated axis is shared by both shapes; the
    extents along ``dim`` are sampled independently. Every extent comes from
    ``random.randint(min_val, max_val)``.

    Returns:
        A ``(shape1, shape2, dim)`` triple.

    Raises:
        ValueError: if ``dim`` is neither 0 nor 1.
    """
    concat_rows = dim == 0
    if not (concat_rows or dim == 1):
        raise ValueError("2d shape only support dim=0 or dim=1")

    # Draw order matters for reproducibility under a seeded RNG:
    # shared extent first, then the two extents along the concat axis.
    shared = random.randint(min_val, max_val)
    extent_a = random.randint(min_val, max_val)
    extent_b = random.randint(min_val, max_val)

    if concat_rows:
        return (extent_a, shared), (extent_b, shared), dim
    return (shared, extent_a), (shared, extent_b), dim
def gen_3d_cat_shapes(dim=0, min_val=1, max_val=4096):
    """Draw a random pair of 3-D shapes that can be concatenated along ``dim``.

    The two axes other than ``dim`` get one shared extent each; the extents
    along ``dim`` are sampled independently for the two shapes. Every extent
    comes from ``random.randint(min_val, max_val)``.

    Returns:
        A ``(shape1, shape2, dim)`` triple.

    Raises:
        ValueError: if ``dim`` is not 0, 1 or 2.
    """
    if dim not in [0, 1, 2]:
        raise ValueError("3d shape only support dim=0/1/2")

    # Draw order (kept stable for seeded runs): the two shared extents in
    # axis order, then the concat-axis extent of each shape.
    shared_lo = random.randint(min_val, max_val)
    shared_hi = random.randint(min_val, max_val)
    cat_first = random.randint(min_val, max_val)
    cat_second = random.randint(min_val, max_val)

    def place(cat_extent):
        # Slot the concat-axis extent into position ``dim`` of the shape.
        if dim == 0:
            return (cat_extent, shared_lo, shared_hi)
        if dim == 1:
            return (shared_lo, cat_extent, shared_hi)
        return (shared_lo, shared_hi, cat_extent)

    return place(cat_first), place(cat_second), dim
def gen_100_cat_shapes(
    num_groups=100,
    mix_ratio=(0.3, 0.3, 0.4),
    min_val=1,
    max_val=4096
):
    """Build a shuffled list of random 1-D/2-D/3-D concat cases.

    Args:
        num_groups: total number of ``(shape1, shape2, dim)`` cases.
        mix_ratio: fractions of 1-D and 2-D cases (first two entries). The
            3-D count is whatever remains of ``num_groups``; the third entry
            is not read.
        min_val: smallest extent passed to the shape generators.
        max_val: largest extent passed to the shape generators.

    Returns:
        A list of ``(shape1, shape2, dim)`` triples in random order.
    """
    count_1d = int(num_groups * mix_ratio[0])
    count_2d = int(num_groups * mix_ratio[1])
    count_3d = num_groups - count_1d - count_2d

    cases = [gen_1d_cat_shapes(min_val, max_val) for _ in range(count_1d)]
    # The concat axis is chosen before the extents are drawn for each case.
    cases.extend(
        gen_2d_cat_shapes(random.choice([0, 1]), min_val, max_val)
        for _ in range(count_2d)
    )
    cases.extend(
        gen_3d_cat_shapes(random.choice([0, 1, 2]), min_val, max_val)
        for _ in range(count_3d)
    )
    random.shuffle(cases)
    return cases
# Module-level pool of 100 randomly generated concat cases, built once at
# import time: 30% 1-D, 40% 2-D, and the remainder 3-D.
# No RNG seed is set in the visible code, so the contents vary between runs.
# NOTE(review): nothing in the visible part of this file reads `full_shape`;
# confirm whether it is still needed.
full_shape = gen_100_cat_shapes(
    num_groups=100,
    mix_ratio=(0.3, 0.4, 0.3),
    min_val=1,
    max_val=4096
)
# Concatenate two 2-D buffers along their last axis, one full-width row tile
# at a time.
#
# Each loop iteration loads a (Y0BLOCK_SUB, in0_x) tile from in_ptr0 and a
# (Y0BLOCK_SUB, in1_x) tile from in_ptr1, places them side by side in a
# zero-initialised (Y0BLOCK_SUB, in0_x + in1_x) tile with insert_slice, and
# stores that tile to out_ptr0.
#
#   in_ptr0, in_ptr1 : inputs, addressed as rows of in0_x / in1_x elements
#   out_ptr0         : output, addressed as rows of in0_x + in1_x elements
#   in0_x, in1_x     : row widths of the inputs (compile-time constants)
#   y0_numel         : number of rows
#   x1_numel         : column bound used in the store mask
#   Y0BLOCK          : row span checked by the per-program row mask
#   Y0BLOCK_SUB      : rows handled per loop iteration
@triton.jit
def _cat_helper_func_2D_1(
    in_ptr0,
    in_ptr1,
    out_ptr0,
    in0_x: tl.constexpr,
    in1_x: tl.constexpr,
    y0_numel,
    x1_numel,
    Y0BLOCK: tl.constexpr,
    Y0BLOCK_SUB: tl.constexpr,
):
    # NOTE(review): the start row is program_id * Y0BLOCK_SUB, while the other
    # kernels in this file use program_id * <per-program block>. With
    # Y0BLOCK > Y0BLOCK_SUB the row ranges of neighbouring programs overlap.
    # Confirm this is intended.
    y0_offset = tl.program_id(0) * Y0BLOCK_SUB
    base_y0 = tl.arange(0, Y0BLOCK_SUB)
    # Ceil-divide: iterations needed to walk Y0BLOCK rows in Y0BLOCK_SUB steps.
    loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB
    # Column indices for each input and for the concatenated output row.
    base_input0_x1 = tl.arange(0, in0_x)[None, :]
    base_input1_x1 = tl.arange(0, in1_x)[None, :]
    x1 = tl.arange(0, in0_x + in1_x)[None, :]
    for loop in range(loops_y0):
        # Row indices of this tile as a column vector.
        y0 = y0_offset + (loop * Y0BLOCK_SUB) + base_y0[:, None]
        # Keep rows inside both this program's span and the tensor.
        y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel)
        x1_mask = x1 < x1_numel
        tmp0 = tl.load(in_ptr0 + (base_input0_x1 + in0_x * y0), y0_mask)
        tmp1 = tl.load(in_ptr1 + (base_input1_x1 + in1_x * y0), y0_mask)
        tmp2 = tl.zeros((Y0BLOCK_SUB, in0_x + in1_x), dtype=tmp0.dtype)
        # insert_slice arguments are presumably (dest, src, offsets, sizes,
        # strides); verify against the CANN extension docs.
        # First input goes at column 0, second at column in0_x.
        tmp3 = extension.insert_slice(tmp2, tmp0, [0, 0], [Y0BLOCK_SUB, in0_x], [1, 1])
        tmp4 = extension.insert_slice(tmp3, tmp1, [0, in0_x], [Y0BLOCK_SUB, in1_x], [1, 1])
        tl.store(out_ptr0 + (x1 + (in0_x + in1_x) * y0), tmp4, x1_mask & y0_mask)
# Concatenate two 2-D buffers of identical extents (y0_numel rows of x1_numel
# elements) along dim 0.
#
# Each program starts at row program_id * Y0BLOCK and walks Y0BLOCK rows in
# steps of Y0BLOCK_SUB, with columns in steps of X1BLOCK_SUB. The matching
# tiles of x and y are stacked into one (2, Y0BLOCK_SUB, X1BLOCK_SUB) tile and
# written with a single store; plane 0 (x) goes to output rows [0, y0_numel)
# and plane 1 (y) goes y0_numel rows further down.
@triton.jit
def triton_unk_fused_cat_dim0_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr):
    y0_offset = tl.program_id(0) * Y0BLOCK
    base_y0 = tl.arange(0, Y0BLOCK_SUB)
    # Ceil-divided trip counts over rows and columns.
    loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB
    base_x1 = tl.arange(0, X1BLOCK_SUB)
    loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB
    for loop_y0 in range(loops_y0):
        y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
        # Rows must lie inside this program's block and inside the tensor.
        y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel)
        for loop_x1 in range(loops_x1):
            x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
            x1_mask = x1 < x1_numel
            tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
            tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
            # Stack the x tile over the y tile, then view the result as two planes.
            tmp10 = tl.zeros((2 * Y0BLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype)
            tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp12 = extension.insert_slice(tmp11, tmp8, [Y0BLOCK_SUB, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp13 = tl.reshape(tmp12, (2, Y0BLOCK_SUB, X1BLOCK_SUB))
            # 3-D index grids: (plane, row, column).
            new_base_x2 = tl.arange(0, X1BLOCK_SUB)
            new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :]
            new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
            new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[None, :, None]
            new_z0 = tl.arange(0, 2)[:, None, None]
            new_x2_mask = new_x2 < x1_numel
            # NOTE(review): unlike y0_mask, this store mask is not clipped to
            # the end of the program's block. That only matters when
            # Y0BLOCK_SUB does not divide Y0BLOCK (the visible caller passes
            # Y0BLOCK_SUB=1). Confirm.
            new_y1_mask = new_y1 < y0_numel
            tl.store(output_ptr + (new_x2 + x1_numel * (new_y1 + y0_numel * new_z0)), tmp13, new_x2_mask & new_y1_mask)
# Concatenate two 2-D buffers along dim 0 when their row counts differ
# (x has y0_numel rows, y has y1_numel rows, both x1_numel elements wide).
#
# Phase 1 covers the first min(y0_numel, y1_numel) rows of both inputs with
# the stacked-tile scheme: x rows keep their row index in the output, y rows
# land y0_numel rows further down.
# Phase 2 copies the remaining rows of whichever input is longer.
#
#   YBLOCK     : rows per program in phase 1
#   YBLOCK_2   : rows per program in phase 2
#   YBLOCK_SUB : rows per loop iteration
#   X1BLOCK_SUB: columns per loop iteration
@triton.jit
def triton_unk_fused_cat_dim0_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, y1_numel, x1_numel, YBLOCK: tl.constexpr,
                                        YBLOCK_2: tl.constexpr, YBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr):
    y0_offset = tl.program_id(0) * YBLOCK
    base_y0 = tl.arange(0, YBLOCK_SUB)
    loops_y0 = (YBLOCK + YBLOCK_SUB - 1) // YBLOCK_SUB
    base_x1 = tl.arange(0, X1BLOCK_SUB)
    loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB
    # min_numel is the row count shared by both inputs. max_numel and
    # clone_numel are assigned but not read anywhere in this kernel.
    min_numel = 0
    max_numel = 0
    clone_numel = 0
    if y0_numel < y1_numel:
        min_numel = y0_numel
        max_numel = y1_numel
        clone_numel = y1_numel - y0_numel
    else:
        min_numel = y1_numel
        max_numel = y0_numel
        clone_numel = y0_numel - y1_numel
    # Phase 1: rows present in both inputs.
    for loops_y in range(loops_y0):
        y0 = y0_offset + (loops_y * YBLOCK_SUB) + base_y0[:, None]
        y0_mask = y0 < min(YBLOCK + y0_offset, min_numel)
        for loop_x1 in range(loops_x1):
            x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
            x1_mask = x1 < x1_numel
            tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
            tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
            # Stack the x tile over the y tile, then view the result as two planes.
            tmp10 = tl.zeros((2 * YBLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype)
            tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [YBLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp12 = extension.insert_slice(tmp11, tmp8, [YBLOCK_SUB, 0], [YBLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp13 = tl.reshape(tmp12, (2, YBLOCK_SUB, X1BLOCK_SUB))
            new_base_x2 = tl.arange(0, X1BLOCK_SUB)
            new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :]
            new_base_y1 = tl.arange(0, YBLOCK_SUB)
            new_y1 = y0_offset + (loops_y * YBLOCK_SUB) + new_base_y1[None, :, None]
            new_z0 = tl.arange(0, 2)[:, None, None]
            new_x2_mask = new_x2 < x1_numel
            new_y1_mask = new_y1 < min_numel
            # Plane 1 (y) is offset by the full height of x (y0_numel rows).
            tl.store(output_ptr + (new_x2 + x1_numel * new_y1 + x1_numel * y0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask)
    # Phase 2: rows beyond min_numel of the longer input.
    loops_y1 = (YBLOCK_2 + YBLOCK_SUB - 1) // YBLOCK_SUB
    y2_offset = tl.program_id(0) * YBLOCK_2 + min_numel
    if y0_numel < y1_numel:
        # y is longer: copy y rows [y0_numel, y1_numel) to output rows shifted
        # down by y0_numel.
        # NOTE(review): the loop variable reuses the name of its trip count.
        for loops_y1 in range(loops_y1):
            y0 = y2_offset + (loops_y1 * YBLOCK_SUB) + base_y0[:, None]
            y0_mask = y0 < min(YBLOCK_2 + y2_offset, y1_numel)
            for loop_x1 in range(loops_x1):
                x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
                x1_mask = x1 < x1_numel
                tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
                new_base_x2 = tl.arange(0, X1BLOCK_SUB)
                new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :]
                new_base_y1 = tl.arange(0, YBLOCK_SUB)
                new_y1 = y2_offset + y0_numel + (loops_y1 * YBLOCK_SUB) + new_base_y1[:, None]
                sum_numel = y0_numel + y1_numel
                new_x2_mask = new_x2 < x1_numel
                new_y1_mask = new_y1 < sum_numel
                tl.store(output_ptr + (new_x2 + x1_numel * new_y1), tmp8, new_x2_mask & new_y1_mask)
    else:
        # x is at least as long: copy x rows [y1_numel, y0_numel) to the same
        # row index in the output.
        for loops_y1 in range(loops_y1):
            y0 = y2_offset + (loops_y1 * YBLOCK_SUB) + base_y0[:, None]
            y0_mask = y0 < min(YBLOCK_2 + y2_offset, y0_numel)
            for loop_x1 in range(loops_x1):
                x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
                x1_mask = x1 < x1_numel
                tmp8 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
                new_base_x2 = tl.arange(0, X1BLOCK_SUB)
                new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :]
                new_base_y1 = tl.arange(0, YBLOCK_SUB)
                new_y1 = y2_offset + (loops_y1 * YBLOCK_SUB) + new_base_y1[:, None]
                new_x2_mask = new_x2 < x1_numel
                new_y1_mask = new_y1 < y0_numel
                tl.store(output_ptr + (new_x2 + x1_numel * new_y1), tmp8, new_x2_mask & new_y1_mask)
# Concatenate two 2-D buffers of identical extents (y0_numel rows of x1_numel
# elements) along dim 1.
#
# Matching tiles of x and y are placed side by side and viewed as
# (Y0BLOCK_SUB, 2, X1BLOCK_SUB). Output rows are 2 * x1_numel wide; half 0 (x)
# starts at column 0 and half 1 (y) at column x1_numel.
@triton.jit
def triton_unk_fused_cat_dim1_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr):
    y0_offset = tl.program_id(0) * Y0BLOCK
    base_y0 = tl.arange(0, Y0BLOCK_SUB)
    # Ceil-divided trip counts over rows and columns.
    loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB
    base_x1 = tl.arange(0, X1BLOCK_SUB)
    loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB
    for loop_y0 in range(loops_y0):
        y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
        # Rows must lie inside this program's block and inside the tensor.
        y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel)
        for loop_x1 in range(loops_x1):
            x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
            x1_mask = x1 < x1_numel
            tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
            tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask)
            # x tile on the left, y tile on the right, then split the columns
            # into two halves.
            tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * X1BLOCK_SUB), dtype=tmp0.dtype)
            tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp12 = extension.insert_slice(tmp11, tmp8, [0, X1BLOCK_SUB], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, X1BLOCK_SUB))
            # 3-D index grids: (row, half, column).
            new_base_x2 = tl.arange(0, X1BLOCK_SUB)
            new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :]
            new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
            new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None]
            new_z0 = tl.arange(0, 2)[None, :, None]
            new_x2_mask = new_x2 < x1_numel
            new_y1_mask = new_y1 < y0_numel
            tl.store(output_ptr + (new_x2 + 2 * x1_numel * new_y1 + x1_numel * new_z0), tmp13, new_x2_mask & new_y1_mask)
# Concatenate two 2-D buffers along dim 1 when their row widths differ
# (x rows are x0_numel wide, y rows are x1_numel wide, both have y0_numel rows).
#
# Phase 1 covers the first min(x0_numel, x1_numel) columns of both inputs with
# the side-by-side tile scheme: x columns keep their index in the output row,
# y columns land x0_numel further right.
# Phase 2 copies the remaining columns of whichever input is wider.
@triton.jit
def triton_unk_fused_cat_dim1_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    y0_offset = tl.program_id(0) * Y0BLOCK
    base_y0 = tl.arange(0, Y0BLOCK_SUB)
    loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB
    base_x = tl.arange(0, XBLOCK_SUB)
    # min_numel: columns present in both inputs. clone_numel: extra columns of
    # the wider input. max_numel is assigned but not read in this kernel.
    min_numel = 0
    max_numel = 0
    clone_numel = 0
    if x0_numel < x1_numel:
        min_numel = x0_numel
        max_numel = x1_numel
        clone_numel = x1_numel - x0_numel
    else:
        min_numel = x1_numel
        max_numel = x0_numel
        clone_numel = x0_numel - x1_numel
    # Ceil-divided trip counts for the shared and the leftover column ranges.
    loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB
    loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB
    # Phase 1: columns present in both inputs.
    for loop_y0 in range(loops_y0):
        y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
        y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel)
        for loop_x in range(loops_x):
            x = (loop_x * XBLOCK_SUB) + base_x[None, :]
            x_mask = x < min_numel
            # Each input is addressed with its own row width.
            tmp0 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask)
            tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask)
            tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype)
            tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1])
            tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1])
            tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, XBLOCK_SUB))
            new_base_x2 = tl.arange(0, XBLOCK_SUB)
            new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :]
            new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
            new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None]
            new_z0 = tl.arange(0, 2)[None, :, None]
            new_x2_mask = new_x2 < min_numel
            new_y1_mask = new_y1 < y0_numel
            # Output rows are x0_numel + x1_numel wide; half 1 (y) starts at
            # column x0_numel.
            sum_numel = x0_numel + x1_numel
            tl.store(output_ptr + (new_x2 + sum_numel * new_y1 + x0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask)
    # Phase 2: leftover columns of the wider input.
    if x0_numel < x1_numel:
        # y is wider: copy y columns [x0_numel, x1_numel) to output columns
        # shifted right by x0_numel.
        for loop_y0 in range(loops_y0):
            y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
            y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel)
            for loop_x2 in range(loops_x2):
                x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel
                x_mask = x < x1_numel
                tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask)
                new_base_x2 = tl.arange(0, XBLOCK_SUB)
                new_x2 = x0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :]
                new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
                new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None]
                sum_numel = x0_numel + x1_numel
                new_x2_mask = new_x2 < sum_numel
                new_y1_mask = new_y1 < y0_numel
                tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask)
    else:
        # x is at least as wide: copy x columns [x1_numel, x0_numel) to the
        # same column index in the output.
        for loop_y0 in range(loops_y0):
            y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
            y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel)
            for loop_x2 in range(loops_x2):
                x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel
                x_mask = x < x0_numel
                tmp8 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask)
                new_base_x2 = tl.arange(0, XBLOCK_SUB)
                new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :]
                new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
                new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None]
                sum_numel = x0_numel + x1_numel
                new_x2_mask = new_x2 < x0_numel
                new_y1_mask = new_y1 < y0_numel
                tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask)
# Concatenate two 3-D buffers along dim 0 (x has z0_numel slabs, y has
# z1_numel slabs; every slab holds y1_numel * x1_numel elements).
#
# The two trailing axes are flattened into one axis of xy_numel elements, so
# the kernel follows the same scheme as the 2-D dim-0 kernel for differing row
# counts: phase 1 handles the first min(z0_numel, z1_numel) slabs of both
# inputs, phase 2 copies the remaining slabs of the longer one.
#
#   ZBLOCK     : slabs per program in phase 1
#   ZBLOCK_2   : slabs per program in phase 2
#   ZBLOCK_SUB : slabs per loop iteration
#   X1BLOCK_SUB: flattened elements per loop iteration
@triton.jit
def triton_unk_fused_cat_3d_dim0(output_ptr, x_ptr, y_ptr, z0_numel, z1_numel, y1_numel, x1_numel, ZBLOCK: tl.constexpr, ZBLOCK_2: tl.constexpr, ZBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr):
    z0_offset = tl.program_id(0) * ZBLOCK
    base_z0 = tl.arange(0, ZBLOCK_SUB)
    loops_z0 = (ZBLOCK + ZBLOCK_SUB - 1) // ZBLOCK_SUB
    # Size of one slab along the flattened (y, x) axis.
    xy_numel = x1_numel * y1_numel
    base_x1 = tl.arange(0, X1BLOCK_SUB)
    loops_x1 = (xy_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB
    # min_numel is the slab count shared by both inputs. max_numel and
    # clone_numel are assigned but not read anywhere in this kernel.
    min_numel = 0
    max_numel = 0
    clone_numel = 0
    if z0_numel < z1_numel:
        min_numel = z0_numel
        max_numel = z1_numel
        clone_numel = z1_numel - z0_numel
    else:
        min_numel = z1_numel
        max_numel = z0_numel
        clone_numel = z0_numel - z1_numel
    # Phase 1: slabs present in both inputs.
    for loops_z in range(loops_z0):
        z0 = z0_offset + (loops_z * ZBLOCK_SUB) + base_z0[:, None]
        z0_mask = z0 < min(ZBLOCK + z0_offset, min_numel)
        for loop_x1 in range(loops_x1):
            x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
            x1_mask = x1 < xy_numel
            tmp0 = tl.load(x_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask)
            tmp8 = tl.load(y_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask)
            # Stack the x tile over the y tile, then view the result as two planes.
            tmp10 = tl.zeros((2 * ZBLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype)
            tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [ZBLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp12 = extension.insert_slice(tmp11, tmp8, [ZBLOCK_SUB, 0], [ZBLOCK_SUB, X1BLOCK_SUB], [1, 1])
            tmp13 = tl.reshape(tmp12, (2, ZBLOCK_SUB, X1BLOCK_SUB))
            new_base_x2 = tl.arange(0, X1BLOCK_SUB)
            new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :]
            new_base_z1 = tl.arange(0, ZBLOCK_SUB)
            new_z1 = z0_offset + (loops_z * ZBLOCK_SUB) + new_base_z1[None, :, None]
            new_z0 = tl.arange(0, 2)[:, None, None]
            new_x2_mask = new_x2 < xy_numel
            new_z1_mask = new_z1 < min_numel
            # Plane 1 (y) is offset by the full depth of x (z0_numel slabs).
            tl.store(output_ptr + (new_x2 + xy_numel * new_z1 + xy_numel * z0_numel * new_z0), tmp13, new_x2_mask & new_z1_mask)
    # Phase 2: slabs beyond min_numel of the longer input.
    loops_z1 = (ZBLOCK_2 + ZBLOCK_SUB - 1) // ZBLOCK_SUB
    z2_offset = tl.program_id(0) * ZBLOCK_2 + min_numel
    if z0_numel < z1_numel:
        # y is longer: copy y slabs [z0_numel, z1_numel) to output slabs
        # shifted by z0_numel.
        # NOTE(review): the loop variable reuses the name of its trip count.
        for loops_z1 in range(loops_z1):
            z0 = z2_offset + (loops_z1 * ZBLOCK_SUB) + base_z0[:, None]
            z0_mask = z0 < min(ZBLOCK_2 + z2_offset, z1_numel)
            for loop_x1 in range(loops_x1):
                x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
                x1_mask = x1 < xy_numel
                tmp8 = tl.load(y_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask)
                new_base_x2 = tl.arange(0, X1BLOCK_SUB)
                new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :]
                new_base_z1 = tl.arange(0, ZBLOCK_SUB)
                new_z1 = z2_offset + z0_numel + (loops_z1 * ZBLOCK_SUB) + new_base_z1[:, None]
                sum_numel = z0_numel + z1_numel
                new_x2_mask = new_x2 < xy_numel
                new_z1_mask = new_z1 < sum_numel
                tl.store(output_ptr + (new_x2 + xy_numel * new_z1), tmp8, new_x2_mask & new_z1_mask)
    else:
        # x is at least as long: copy x slabs [z1_numel, z0_numel) to the same
        # slab index in the output.
        for loops_z1 in range(loops_z1):
            z0 = z2_offset + (loops_z1 * ZBLOCK_SUB) + base_z0[:, None]
            z0_mask = z0 < min(ZBLOCK_2 + z2_offset, z0_numel)
            for loop_x1 in range(loops_x1):
                x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :]
                x1_mask = x1 < xy_numel
                tmp8 = tl.load(x_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask)
                new_base_x2 = tl.arange(0, X1BLOCK_SUB)
                new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :]
                new_base_z1 = tl.arange(0, ZBLOCK_SUB)
                new_z1 = z2_offset + (loops_z1 * ZBLOCK_SUB) + new_base_z1[:, None]
                new_x2_mask = new_x2 < xy_numel
                new_z1_mask = new_z1 < z0_numel
                tl.store(output_ptr + (new_x2 + xy_numel * new_z1), tmp8, new_x2_mask & new_z1_mask)
# Concatenate two 3-D buffers along dim 1 (both have z0_numel slabs and rows
# of x0_numel elements; x has y0_numel rows per slab, y has y1_numel).
#
# Each slab is treated as one flat run of rows * x0_numel elements, so the
# kernel follows the same scheme as the 2-D dim-1 kernel for differing widths:
# phase 1 handles the first min_numel flat elements of every slab from both
# inputs, phase 2 copies the remaining elements of the input with more rows.
# It returns early after phase 1 when the row counts are equal.
@triton.jit
def triton_unk_fused_cat_3d_dim1(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, y1_numel, x0_numel, Z0BLOCK: tl.constexpr, Z0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    z0_offset = tl.program_id(0) * Z0BLOCK
    base_z0 = tl.arange(0, Z0BLOCK_SUB)
    loops_z0 = (Z0BLOCK + Z0BLOCK_SUB - 1) // Z0BLOCK_SUB
    base_x = tl.arange(0, XBLOCK_SUB)
    # All three counts are in flat elements per slab (rows * x0_numel).
    min_numel = 0
    max_numel = 0
    clone_numel = 0
    if y0_numel < y1_numel:
        min_numel = y0_numel * x0_numel
        max_numel = y1_numel * x0_numel
        clone_numel = (y1_numel - y0_numel) * x0_numel
    else:
        min_numel = y1_numel * x0_numel
        max_numel = y0_numel * x0_numel
        clone_numel = (y0_numel - y1_numel) * x0_numel
    # Ceil-divided trip counts for the shared and the leftover element ranges.
    loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB
    loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB
    # Phase 1: elements present in both inputs.
    for loop_z0 in range(loops_z0):
        z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None]
        z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel)
        for loop_x in range(loops_x):
            x = (loop_x * XBLOCK_SUB) + base_x[None, :]
            x_mask = x < min_numel
            # Each input is addressed with its own slab size.
            tmp0 = tl.load(x_ptr + (x + x0_numel * y0_numel * z0), x_mask & z0_mask)
            tmp8 = tl.load(y_ptr + (x + x0_numel * y1_numel * z0), x_mask & z0_mask)
            tmp10 = tl.zeros((Z0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype)
            tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Z0BLOCK_SUB, XBLOCK_SUB], [1, 1])
            tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Z0BLOCK_SUB, XBLOCK_SUB], [1, 1])
            tmp13 = tl.reshape(tmp12, (Z0BLOCK_SUB, 2, XBLOCK_SUB))
            new_base_x2 = tl.arange(0, XBLOCK_SUB)
            new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :]
            new_base_z1 = tl.arange(0, Z0BLOCK_SUB)
            new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None, None]
            new_z0 = tl.arange(0, 2)[None, :, None]
            new_x2_mask = new_x2 < min_numel
            new_z1_mask = new_z1 < z0_numel
            # Output slab size is min_numel + max_numel elements; half 1 (y)
            # starts after the x part of the slab (x0_numel * y0_numel).
            sum_numel = min_numel + max_numel
            tl.store(output_ptr + (new_x2 + sum_numel * new_z1 + x0_numel * y0_numel * new_z0), tmp13, new_x2_mask & new_z1_mask)
    # Equal row counts leave nothing for phase 2.
    if y0_numel == y1_numel:
        return
    # Phase 2: leftover elements of the input with more rows.
    if y0_numel < y1_numel:
        # y has more rows: copy its tail to just after the part written in phase 1.
        for loop_z0 in range(loops_z0):
            z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None]
            z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel)
            for loop_x2 in range(loops_x2):
                x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel
                x_mask = x < y1_numel * x0_numel
                tmp8 = tl.load(y_ptr + (x + x0_numel * y1_numel * z0), x_mask & z0_mask)
                new_base_x2 = tl.arange(0, XBLOCK_SUB)
                new_x2 = x0_numel * y0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :]
                new_base_z1 = tl.arange(0, Z0BLOCK_SUB)
                new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None]
                sum_numel = min_numel + max_numel
                new_x2_mask = new_x2 < sum_numel
                new_z1_mask = new_z1 < z0_numel
                tl.store(output_ptr + (new_x2 + sum_numel * new_z1), tmp8, new_x2_mask & new_z1_mask)
    else:
        # x has more rows: copy its tail to the same flat offset in the output slab.
        for loop_z0 in range(loops_z0):
            z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None]
            z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel)
            for loop_x2 in range(loops_x2):
                x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel
                x_mask = x < x0_numel * y0_numel
                tmp8 = tl.load(x_ptr + (x + x0_numel * y0_numel * z0), x_mask & z0_mask)
                new_base_x2 = tl.arange(0, XBLOCK_SUB)
                new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :]
                new_base_z1 = tl.arange(0, Z0BLOCK_SUB)
                new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None]
                sum_numel = min_numel + max_numel
                new_x2_mask = new_x2 < x0_numel * y0_numel
                new_z1_mask = new_z1 < z0_numel
                tl.store(output_ptr + (new_x2 + sum_numel * new_z1), tmp8, new_x2_mask & new_z1_mask)
# Concatenate two 3-D buffers along dim 2 (both have z0_numel * y0_numel rows;
# x rows are x0_numel wide, y rows are x1_numel wide).
#
# The two leading axes are flattened into one row axis of zy_numel rows, so
# the kernel follows the same scheme as the 2-D dim-1 kernel for differing
# widths: phase 1 handles the first min(x0_numel, x1_numel) columns of both
# inputs, phase 2 copies the remaining columns of the wider one.
# It returns early after phase 1 when the widths are equal.
@triton.jit
def triton_unk_fused_cat_3d_dim2(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    y0_offset = tl.program_id(0) * Y0BLOCK
    base_y0 = tl.arange(0, Y0BLOCK_SUB)
    loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB
    base_x = tl.arange(0, XBLOCK_SUB)
    # min_numel: columns present in both inputs. clone_numel: extra columns of
    # the wider input. max_numel is assigned but not read in this kernel.
    min_numel = 0
    max_numel = 0
    clone_numel = 0
    # Total number of rows after flattening the two leading axes.
    zy_numel = z0_numel * y0_numel
    if x0_numel < x1_numel:
        min_numel = x0_numel
        max_numel = x1_numel
        clone_numel = x1_numel - x0_numel
    else:
        min_numel = x1_numel
        max_numel = x0_numel
        clone_numel = x0_numel - x1_numel
    # Ceil-divided trip counts for the shared and the leftover column ranges.
    loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB
    loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB
    # Phase 1: columns present in both inputs.
    for loop_y0 in range(loops_y0):
        y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
        y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel)
        for loop_x in range(loops_x):
            x = (loop_x * XBLOCK_SUB) + base_x[None, :]
            x_mask = x < min_numel
            # Each input is addressed with its own row width.
            tmp0 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask)
            tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask)
            tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype)
            tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1])
            tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1])
            tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, XBLOCK_SUB))
            new_base_x2 = tl.arange(0, XBLOCK_SUB)
            new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :]
            new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
            new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None]
            new_z0 = tl.arange(0, 2)[None, :, None]
            new_x2_mask = new_x2 < min_numel
            new_y1_mask = new_y1 < zy_numel
            # Output rows are x0_numel + x1_numel wide; half 1 (y) starts at
            # column x0_numel.
            sum_numel = x0_numel + x1_numel
            tl.store(output_ptr + (new_x2 + sum_numel * new_y1 + x0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask)
    # Equal widths leave nothing for phase 2.
    if x0_numel == x1_numel:
        return
    # Phase 2: leftover columns of the wider input.
    if x0_numel < x1_numel:
        # y is wider: copy y columns [x0_numel, x1_numel) to output columns
        # shifted right by x0_numel.
        for loop_y0 in range(loops_y0):
            y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
            y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel)
            for loop_x2 in range(loops_x2):
                x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel
                x_mask = x < x1_numel
                tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask)
                new_base_x2 = tl.arange(0, XBLOCK_SUB)
                new_x2 = x0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :]
                new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
                new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None]
                sum_numel = x0_numel + x1_numel
                new_x2_mask = new_x2 < sum_numel
                new_y1_mask = new_y1 < zy_numel
                tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask)
    else:
        # x is wider: copy x columns [x1_numel, x0_numel) to the same column
        # index in the output.
        for loop_y0 in range(loops_y0):
            y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None]
            y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel)
            for loop_x2 in range(loops_x2):
                x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel
                x_mask = x < x0_numel
                tmp8 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask)
                new_base_x2 = tl.arange(0, XBLOCK_SUB)
                new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :]
                new_base_y1 = tl.arange(0, Y0BLOCK_SUB)
                new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None]
                sum_numel = x0_numel + x1_numel
                new_x2_mask = new_x2 < x0_numel
                new_y1_mask = new_y1 < zy_numel
                tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask)
# Hand-picked (shape0, shape1, cat_dim) cases consumed by test_cat_bigshape.
# Each group mixes equal and unequal extents along the concat axis.
testlist = [
    # 1-D inputs
    ((3,), (3,), 0),
    ((7,), (9,), 0),
    ((13,), (11,), 0),
    ((2047,), (2047,), 0),
    ((2701,), (3003,), 0),
    ((4093,), (3095,), 0),
    # 2-D inputs, concatenated along dim 0
    ((3, 5), (3, 5), 0),
    ((1005, 300), (2007, 300), 0),
    ((1307, 400), (309, 400), 0),
    ((303, 500), (303, 500), 0),
    # 2-D inputs, concatenated along dim 1
    ((7, 9), (7, 9), 1),
    ((100, 1001), (100, 2003), 1),
    ((200, 2005), (200, 207), 1),
    ((300, 707), (300, 707), 1),
    # 3-D inputs, concatenated along dim 0
    ((378, 200, 300), (101, 200, 300), 0),
    ((378, 70, 50), (601, 70, 50), 0),
    # 3-D inputs, concatenated along dim 1
    ((100, 452, 300), (100, 201, 300), 1),
    ((65, 1735, 57), (65, 2001, 57), 1),
    # 3-D inputs, concatenated along dim 2
    ((87, 200, 387), (87, 200, 501), 2),
    ((20, 337, 543), (20, 337, 401), 2),
]
@pytest.mark.parametrize('testlists', testlist)
@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16'])
def test_cat_bigshape(testlists, dtype):
    """Check the Triton concat kernels against ``torch.cat`` on NPU tensors.

    Args:
        testlists: ``(shape0, shape1, cat_dim)`` triple taken from ``testlist``.
        dtype: name of a torch dtype attribute, e.g. ``'float16'``.

    Kernel selection:
        * 3-D inputs use the dedicated 3-D kernel for ``cat_dim``.
        * 1-D inputs are viewed as single-row 2-D tensors and concatenated
          along dim 1; the added axis is removed again before comparison.
        * 1-D/2-D results with more than 512 elements, and every 2-D dim-0
          case, use the tiled 2-D kernels; the rest use the small helper kernel.
    """
    # Plain attribute lookup; avoids eval() on a string built at runtime.
    torch_dtype = getattr(torch, dtype)
    np_x0 = test_common.generate_numpy(testlists[0], dtype)
    np_x1 = test_common.generate_numpy(testlists[1], dtype)
    cat_dim = testlists[2]
    x0 = torch.from_numpy(np_x0).to(torch_dtype).npu()
    x1 = torch.from_numpy(np_x1).to(torch_dtype).npu()
    rank = len(x0.shape)
    if rank > 3:
        pytest.skip("dim > 3 for 3D+ tensor, skipping.")

    torch_res = torch.cat([x0, x1], dim=cat_dim)
    triton_res = torch.zeros_like(torch_res)
    num_core = 32
    grid = (num_core, 1, 1)

    def per_core(extent):
        # Ceil-divide an extent across the launched programs.
        return (extent + num_core - 1) // num_core

    if rank == 3:
        if cat_dim == 0:
            shared = min(x0.shape[0], x1.shape[0])
            extra = max(x0.shape[0], x1.shape[0]) - shared
            ZBLOCK = per_core(shared)
            ZBLOCK_2 = per_core(extra)
            triton_unk_fused_cat_3d_dim0[grid](triton_res, x0, x1, x0.shape[0], x1.shape[0], x0.shape[1], x0.shape[2], ZBLOCK, ZBLOCK_2, 1, 256)
        elif cat_dim == 1:
            Z0BLOCK = per_core(x0.shape[0])
            triton_unk_fused_cat_3d_dim1[grid](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1], x1.shape[2], Z0BLOCK, 1, 256)
        else:
            Y0BLOCK = per_core(x0.shape[0] * x0.shape[1])
            triton_unk_fused_cat_3d_dim2[grid](triton_res, x0, x1, x0.shape[0], x0.shape[1], x0.shape[2], x1.shape[2], Y0BLOCK, 1, 256)
        test_common.validate_cmp(dtype, torch_res, triton_res)
        return

    # Decide on the kernel family before any reshaping: the rule depends on
    # the original rank and concat axis.
    use_tiled = torch_res.numel() > 512 or (cat_dim == 0 and rank == 2)

    is_1d = rank == 1
    if is_1d:
        # View 1-D data as one row so the 2-D kernels apply; concatenating
        # 1-D tensors is then a dim-1 concat of single-row tensors.
        # unsqueeze returns a view, so kernel writes land in the same storage.
        x0 = torch.unsqueeze(x0, dim=0)
        x1 = torch.unsqueeze(x1, dim=0)
        triton_res = torch.unsqueeze(triton_res, dim=0)
        cat_dim = 1

    if not use_tiled:
        _cat_helper_func_2D_1[grid](x0, x1, triton_res, x0.shape[1], x1.shape[1], x0.shape[0], x0.shape[1] + x1.shape[1], 256, 16)
    elif cat_dim == 1:
        Y0BLOCK = per_core(x0.shape[0])
        if x0.shape[1] == x1.shape[1]:
            triton_unk_fused_cat_dim1_sameshape[grid](triton_res, x0, x1, x0.shape[0], x0.shape[1], Y0BLOCK, 1, 256)
        else:
            triton_unk_fused_cat_dim1_diffshape[grid](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1], Y0BLOCK, 1, 256)
    elif x0.shape[0] == x1.shape[0]:
        Y0BLOCK = per_core(x0.shape[0])
        triton_unk_fused_cat_dim0_sameshape[grid](triton_res, x0, x1, x0.shape[0], x0.shape[1], Y0BLOCK, 1, 256)
    else:
        shared = min(x0.shape[0], x1.shape[0])
        extra = max(x0.shape[0], x1.shape[0]) - shared
        YBLOCK = per_core(shared)
        YBLOCK_2 = per_core(extra)
        triton_unk_fused_cat_dim0_diffshape[grid](triton_res, x0, x1, x0.shape[0], x1.shape[0], x1.shape[1], YBLOCK, YBLOCK_2, 1, 256)

    if is_1d:
        # Drop only the axis added above; a bare squeeze() would also remove
        # any other unit-length axis.
        triton_res = triton_res.squeeze(0)
    test_common.validate_cmp(dtype, torch_res, triton_res)