# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""
Autotune
=============
"""
import pytest
import torch
import torch_npu
import triton
import triton.language as tl
# Return a set of different kernel configurations for autotune
def get_autotune_config():
    """Return the candidate kernel configurations offered to the autotuner.

    Every candidate sets two entries in the ``triton.Config`` kwargs:

    * ``XS`` -- the block size, bound to the kernel's ``XS`` constexpr.
    * ``multibuffer`` -- forwarded unchanged; the kernel signature does not
      declare it, so it is presumably consumed by the backend compiler
      (NOTE(review): confirm against the Triton-Ascend documentation).

    Returns:
        A list of ``triton.Config`` objects, in the order listed below.
    """
    # (XS, multibuffer) pairs; the order here is the order of the result.
    candidates = (
        (1 * 128, True),
        (12 * 1024, True),
        (12 * 1024, False),
        (8 * 1024, True),
    )
    return [
        triton.Config({'XS': block_size, 'multibuffer': use_multibuffer})
        for block_size, use_multibuffer in candidates
    ]
# Use @autotune decorator to automatically select the best kernel configuration
@triton.autotune(
    configs=get_autotune_config(),
    key=["numel"],  # the autotuner is keyed on the ``numel`` argument
)
@triton.jit
def triton_calc_kernel(
    out_ptr0, in_ptr0, in_ptr1, numel,
    XS: tl.constexpr,  # Number of elements handled by each program instance
):
    """Write ``exp(in_ptr0) + in_ptr1 + i`` to ``out_ptr0``, element-wise.

    Each program instance covers the ``XS`` consecutive flat offsets that
    start at ``program_id(0) * XS``; offsets at or beyond ``numel`` are
    masked out of every load and store.

    The load/compute/store sequence is repeated for ``i`` in
    ``range(10000)``. Every iteration overwrites the same output elements,
    so the values left in ``out_ptr0`` come from the last iteration
    (``i == 9999``).
    NOTE(review): the repetition presumably exists to give the autotuner a
    long enough runtime to compare configurations -- confirm the intent.

    Args:
        out_ptr0: Pointer to the output buffer.
        in_ptr0: Pointer to the input passed through ``exp``.
        in_ptr1: Pointer to the input that is added to the exponential.
        numel: Number of valid elements; bounds the access mask.
        XS: Compile-time block size supplied by the selected config.
    """
    block_id = tl.program_id(0)
    offsets = block_id * XS + tl.arange(0, XS)
    in_bounds = offsets < numel
    for step in range(10000):
        # Masked-off lanes read 0.0 and are never written back.
        lhs = tl.load(in_ptr0 + offsets, mask=in_bounds, other=0.0)
        rhs = tl.load(in_ptr1 + offsets, mask=in_bounds, other=0.0)
        result = tl.math.exp(lhs) + rhs + step
        tl.store(out_ptr0 + offsets, result, mask=in_bounds)
# Function to call the Triton kernel with autotuned configuration
def triton_calc_func(x0, x1):
    """Run ``triton_calc_kernel`` on ``x0`` and ``x1`` and return the result.

    The output is allocated with ``torch.empty_like(x0)`` and the kernel is
    launched over ``x0.numel()`` elements.

    NOTE(review): the kernel addresses all three buffers with the same flat
    offsets, so this assumes ``x1`` holds at least ``x0.numel()`` elements
    and that the tensors are contiguous -- nothing here checks that.

    Args:
        x0: Tensor passed to the kernel as ``in_ptr0``.
        x1: Tensor passed to the kernel as ``in_ptr1``.

    Returns:
        A new tensor, created like ``x0``, filled by the kernel.
    """
    out = torch.empty_like(x0)
    total = x0.numel()

    # The grid is a callable so the number of programs follows the ``XS``
    # value of whichever configuration is being launched.
    launch_grid = lambda meta: (triton.cdiv(total, meta["XS"]), 1, 1)  # noqa: E731

    triton_calc_kernel[launch_grid](out, x0, x1, total)
    return out
# Reference implementation using PyTorch for correctness check
def torch_calc_func(x0, x1):
    """Compute the PyTorch reference value ``exp(x0) + x1 + 10000 - 1``.

    The operations are applied left to right, exactly as written in the
    return expression.

    Args:
        x0: Tensor passed through ``torch.exp``.
        x1: Tensor added to the exponential.

    Returns:
        The element-wise result tensor.
    """
    exp_x0 = torch.exp(x0)
    partial_sum = exp_x0 + x1
    return partial_sum + 10000 - 1
# ==================== Pytest Test ====================
def test_triton_autotune():
    """Check the autotuned Triton kernel against the PyTorch reference.

    Two random float32 vectors of ``192 * 1024`` elements are created on the
    ``"npu"`` device, run through both implementations, and compared with
    ``torch.testing.assert_close``.
    """
    num_elements = 192 * 1024
    tensor_opts = {"dtype": torch.float32, "device": "npu"}

    x0 = torch.randn((num_elements,), **tensor_opts)
    x1 = torch.randn((num_elements,), **tensor_opts)

    # The reference is evaluated first, then the Triton implementation.
    expected = torch_calc_func(x0, x1)
    actual = triton_calc_func(x0, x1)

    torch.testing.assert_close(actual, expected)