import torch
import triton
import triton.language as tl
@triton.jit
def atomic_rmw_useanalysis_kernel(
    input_ptr,
    output_ptr,
    m_ptr,
    d_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Atomically add ``exp(x - m) * (2 * x - d)`` into ``output_ptr``.

    Each program derives the index of the block it processes from its
    program id through a float sqrt/ceil expression followed by a cast to
    int32, loads one ``BLOCK_SIZE``-wide tile from each of the three inputs
    and atomically adds the combined value at the same offsets of the output.

    The accompanying test treats an all-zero output as evidence that the
    dependencies feeding the atomic add were dropped, so the operation
    sequence below is kept exactly as written.

    Args:
        input_ptr: Base pointer of the ``x`` values.
        output_ptr: Base pointer of the accumulation buffer (updated via
            ``tl.atomic_add``).
        m_ptr: Base pointer of the values subtracted from ``x`` before
            ``exp``.
        d_ptr: Base pointer of the values subtracted from ``2 * x``.
        N: Compile-time upper bound; offsets ``>= N`` are masked out.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    pid = tl.program_id(0)

    # Map the program id to a block index:
    #   task_idx = ceil((15 - sqrt(15^2 + 8 * (7 - 8 * pid))) / 2)
    # For pid 0..4 this evaluates to block indices 0, 1, 2, 3, 5.
    # NOTE(review): for pid >= 5 `delta` is negative (e.g. -39 at pid 5), so
    # tl.sqrt receives a negative input (presumably yielding NaN) and the
    # int32 cast below is then backend-defined. The visible test launches 8
    # programs -- confirm this is intended.
    base_idx = pid * 8
    term1 = 15.0 * 15.0
    term2 = 8.0 * (7.0 - base_idx)
    delta = term1 + term2
    sqrt_delta = tl.sqrt(delta)
    task_idx = tl.ceil((15.0 - sqrt_delta) / 2.0)
    task_idx_i32 = task_idx.to(tl.int32)

    # Element offsets of this program's tile.
    # NOTE(review): the mask only bounds offsets from above; a negative
    # task index would not be masked out.
    block_start = task_idx_i32 * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    # Masked-out lanes read 0.0 from all three inputs, which makes their
    # result exp(0) * (0 - 0) == 0; they are also masked in the atomic add.
    data = tl.load(input_ptr + offsets, mask=mask, other=0.0)
    m_val = tl.load(m_ptr + offsets, mask=mask, other=0.0)
    d_val = tl.load(d_ptr + offsets, mask=mask, other=0.0)

    scaled = data - m_val
    p = tl.exp(scaled)
    result = p * (data * 2.0 - d_val)

    # The output uses the same offsets as the inputs.
    output_offsets = offsets
    tl.atomic_add(output_ptr + output_offsets, result, mask=mask)
def test_atomic_rmw_useanalysis() -> None:
    """Launch the atomic-add kernel on the NPU and check it wrote something.

    Three seeded random float32 inputs of length ``N`` are fed to
    ``atomic_rmw_useanalysis_kernel`` together with a zero-initialised
    output buffer. The test passes when the sum of absolute output values
    is non-zero.

    Raises:
        AssertionError: If the output buffer is still all zeros after the
            kernel launch.
    """
    DEVICE = "npu"
    N = 1024
    BLOCK_SIZE = 128

    # Fixed seed so the inputs are reproducible across runs.
    torch.manual_seed(42)
    input_data = torch.randn(N, dtype=torch.float32, device=DEVICE)
    m_data = torch.randn(N, dtype=torch.float32, device=DEVICE)
    d_data = torch.randn(N, dtype=torch.float32, device=DEVICE)
    # The kernel accumulates with atomic adds, so the output starts at zero.
    output_data = torch.zeros(N, dtype=torch.float32, device=DEVICE)

    # One program per BLOCK_SIZE-wide tile: N / BLOCK_SIZE == 8.
    grid = (8,)
    atomic_rmw_useanalysis_kernel[grid](
        input_data,
        output_data,
        m_data,
        d_data,
        N=N,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # NOTE(review): a NaN sum compares unequal to 0 and would therefore be
    # reported as success -- confirm whether a finiteness check is wanted.
    output_sum = output_data.abs().sum().item()
    if output_sum == 0:
        raise AssertionError("UseAnalysis bug detected: atomic_rmw dependencies were erased")
    print(" AtomicRMW UseAnalysis is working correctly.")