import torch
try:
import triton
import triton.language as tl
TRITON_AVAILABLE = True
except ImportError:
TRITON_AVAILABLE = False
pass
if TRITON_AVAILABLE:

    @triton.jit
    def _add_kernel(
        A, B, C,
        M, N,
        stride_am, stride_an,
        stride_bm, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_R: tl.constexpr,
        BLOCK_SIZE_C: tl.constexpr,
    ):
        """
        Element-wise addition C = A + B over an (M, N) index space.

        - Grid is 1D: program `pid` handles rows
          [pid * BLOCK_SIZE_R, (pid + 1) * BLOCK_SIZE_R) and loops over all
          column tiles of width BLOCK_SIZE_C in that row band.
        - Each of A, B, C is addressed through its own (row, column) strides,
          so the kernel itself does not require contiguous tensors.
        - Rows >= M and columns >= N are masked out, so M and N need not be
          multiples of the block sizes.
        """
        # Offsets are computed in int64: with the default int32 arithmetic,
        # row * stride (or col * stride) wraps around once a tensor spans
        # >= 2**31 elements, producing out-of-bounds addresses.
        pid = tl.program_id(0).to(tl.int64)
        offs_r = pid * BLOCK_SIZE_R + tl.arange(0, BLOCK_SIZE_R)
        mask_r = offs_r < M
        idx_r = offs_r[:, None]
        for start_c in range(0, N, BLOCK_SIZE_C):
            offs_c = (start_c + tl.arange(0, BLOCK_SIZE_C)).to(tl.int64)
            idx_c = offs_c[None, :]
            mask = mask_r[:, None] & (idx_c < N)
            a_ptrs = A + idx_r * stride_am + idx_c * stride_an
            b_ptrs = B + idx_r * stride_bm + idx_c * stride_bn
            # Masked-out lanes are filled with 0 but never stored.
            a = tl.load(a_ptrs, mask=mask, other=0.0)
            b = tl.load(b_ptrs, mask=mask, other=0.0)
            c = a + b
            c_ptrs = C + idx_r * stride_cm + idx_c * stride_cn
            tl.store(c_ptrs, c, mask=mask)
def add_fwd(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return ``A + B`` computed element-wise by the Triton ``_add_kernel``.

    Both inputs must have identical shapes, be contiguous and live on the same
    device. Inputs of any rank are accepted: because they are contiguous, they
    are viewed as a row-major ``(M, N)`` matrix (``N`` = last dimension, 1 for
    0-d tensors) before launching the kernel. 2D inputs are launched with the
    same ``M``, ``N`` and block sizes as before.

    Args:
        A: first operand.
        B: second operand; must match ``A`` in shape and device.

    Returns:
        A new tensor allocated with ``torch.empty_like(A)`` (so it has ``A``'s
        shape, dtype and device) holding ``A + B``.

    Raises:
        ValueError: if the shapes differ, either input is non-contiguous, or
            the inputs are on different devices.
        RuntimeError: if Triton could not be imported and work is required.
    """
    if A.shape != B.shape:
        raise ValueError(
            f"input shapes of add_fwd should be the same, but got {A.shape} and {B.shape}"
        )
    if not A.is_contiguous() or not B.is_contiguous():
        raise ValueError(
            f"input of add_fwd should be contiguous, but got {A.is_contiguous()} and {B.is_contiguous()}"
        )
    if A.device != B.device:
        raise ValueError(
            f"inputs of add_fwd should be on the same device, but got {A.device} and {B.device}"
        )

    C = torch.empty_like(A)
    # Nothing to compute. Also required for correctness: N == 0 would make
    # BLOCK_SIZE_C zero (division by zero below) and M == 0 an empty grid.
    if C.numel() == 0:
        return C

    if not TRITON_AVAILABLE:
        raise RuntimeError("add_fwd requires Triton, which could not be imported")

    # Same-shaped contiguous tensors share a row-major layout, so viewing them
    # as (M, N) keeps the element-wise correspondence intact.
    N = A.shape[-1] if A.dim() > 0 else 1
    M = A.numel() // N
    a2, b2, c2 = A.view(M, N), B.view(M, N), C.view(M, N)

    # Column tile covers up to 1024 columns; rows per program keep the tile
    # at roughly 8192 elements (capped at 64 rows, at least 1).
    BLOCK_SIZE_C = min(triton.next_power_of_2(N), 1024)
    BLOCK_SIZE_R = min(64, max(1, 8192 // BLOCK_SIZE_C))
    grid = (triton.cdiv(M, BLOCK_SIZE_R),)
    _add_kernel[grid](
        a2, b2, c2,
        M, N,
        a2.stride(0), a2.stride(1),
        b2.stride(0), b2.stride(1),
        c2.stride(0), c2.stride(1),
        BLOCK_SIZE_R=BLOCK_SIZE_R,
        BLOCK_SIZE_C=BLOCK_SIZE_C,
    )
    return C