# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""
Autotune
=============
"""
import pytest
import torch
import torch_npu
import triton
import triton.language as tl


# Return a set of different kernel configurations for autotune
def get_autotune_config():
    return [
        triton.Config({'XS': 1 * 128, 'multibuffer': True}),
        triton.Config({'XS': 12 * 1024, 'multibuffer': True}),
        triton.Config({'XS': 12 * 1024, 'multibuffer': False}),
        triton.Config({'XS': 8 * 1024, 'multibuffer': True}),
    ]


# Use @autotune decorator to automatically select the best kernel configuration
@triton.autotune(
    configs=get_autotune_config(),
    key=["numel"],
)
@triton.jit
def triton_calc_kernel(
        out_ptr0, in_ptr0, in_ptr1, numel,
        XS: tl.constexpr  # Block size controlling how many elements each thread block processes
):
    pid = tl.program_id(0)
    idx = pid * XS + tl.arange(0, XS)
    msk = idx < numel
    for i in range(10000):
        tmp0 = tl.load(in_ptr0 + idx, mask=msk, other=0.0)
        tmp1 = tl.load(in_ptr1 + idx, mask=msk, other=0.0)
        tmp2 = tl.math.exp(tmp0) + tmp1 + i
        tl.store(out_ptr0 + idx, tmp2, mask=msk)


# Function to call the Triton kernel with autotuned configuration
def triton_calc_func(x0, x1):
    n = x0.numel()
    y0 = torch.empty_like(x0)

    def grid(meta):
        return (triton.cdiv(n, meta["XS"]), 1, 1)

    triton_calc_kernel[grid](y0, x0, x1, n)
    return y0


# Reference implementation using PyTorch for correctness check
def torch_calc_func(x0, x1):
    return torch.exp(x0) + x1 + 10000 - 1


# ==================== Pytest Test ====================
def test_triton_autotune():
    DEV = "npu"
    DTYPE = torch.float32
    N = 192 * 1024
    x0 = torch.randn((N,), dtype=DTYPE, device=DEV)
    x1 = torch.randn((N,), dtype=DTYPE, device=DEV)

    torch_ref = torch_calc_func(x0, x1)
    triton_cal = triton_calc_func(x0, x1)

    torch.testing.assert_close(triton_cal, torch_ref)