# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


import math
import triton
import triton.language as tl
import torch
import torch_npu
import pytest
import test_common
from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size

# <<<<<<< test_min_1d
def torch_min(x0, dim, keepdim):
    inp = x0 if x0.device == "cpu" else x0.cpu()
    return torch.min(inp, dim=dim, keepdim=keepdim)[0].npu()

@triton.jit
def triton_min_1d(in_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) + tl.arange(0, XBLOCK)
    tmp0 = tl.load(in_ptr0 + xoffset, None)
    tmp4 = tl.min(tmp0, 0)
    tl.store(out_ptr1, tmp4, None)

@pytest.mark.parametrize('shape', TestUtils.test_shape1d)
@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
def test_min_1d(dtype, shape):
    if check_ub_mem_overflow(dtype, shape):
        return
    x0 = test_common.generate_tensor(shape, dtype).npu()
    triton_res = torch.empty(1, dtype=eval("torch." + dtype)).npu()
    numel = shape[0]
    triton_min_1d[1,1,1](x0, triton_res, numel, numel)
    torch_res = torch_min(x0, dim=0, keepdim=True)
    test_common.validate_cmp(dtype, triton_res, torch_res)

# >>>>>>> test_min_1d

# <<<<<<< test_min_2d
@triton.jit
def triton_min_2d(in_ptr0, out_ptr0, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr):
    mblk_idx = tl.arange(0,MNUMEL)
    nblk_idx = tl.arange(0,NNUMEL)
    mmask = mblk_idx < M
    nmask = nblk_idx < N
    mask = (mmask[:,None]) & (nmask[None,:])
    idx = mblk_idx[:,None] * N + nblk_idx[None,:]
    x = tl.load(in_ptr0 + idx, mask = mask, other = -float('inf'))
    tmp4 = tl.min(x, dim)
    if dim == 0:
        tl.store(out_ptr0 + tl.arange(0,N), tmp4, None)
    else:
        tl.store(out_ptr0 + tl.arange(0,M), tmp4, None)

@pytest.mark.parametrize('shape', TestUtils.test_shape2d)
@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('dim', [0, 1])
def test_min_2d(dtype, shape, dim):
    dtype_size = get_dtype_size(dtype)
    if dtype == 'int8' or dtype == 'bool':
        if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20):
            pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow")
    elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5):
        pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow")
    shapex, shapey = shape
    x0 = test_common.generate_tensor(shape, dtype).npu()
    triton_res = torch.empty([shape[1-dim], ], dtype=eval("torch." + dtype)).npu()
    triton_min_2d[1,1,1](x0, triton_res, dim, shapex, shapey, shapex, shapey)
    torch_res = torch_min(x0, dim=dim, keepdim=False)
    test_common.validate_cmp(dtype, triton_res, torch_res)

# >>>>>>> test_min_2d

# <<<<<<< test_min_3d
def torch_min_3d(x0, no_reduce_dim):
    inp = x0 if x0.device == "cpu" else x0.cpu()
    if no_reduce_dim == 0:
        return torch.min(torch.min(inp, 1)[0], 1)[0].npu()
    elif no_reduce_dim == 1:
        return torch.min(torch.min(inp, 0)[0], 1)[0].npu()
    elif no_reduce_dim == 2:
        return torch.min(torch.min(inp, 0)[0], 0)[0].npu()
    else:
        assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}"

@triton.jit
def triton_min_3d_0_1(in_ptr, out_ptr,
    xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr,
    XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr):
    xidx = tl.arange(0,XB)
    yidx = tl.arange(0,YB)
    zidx = tl.arange(0,ZB)
    idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:]
    x = tl.load(in_ptr + idx)
    tmp = tl.min(x, 0)
    ret = tl.min(tmp, 0)
    oidx = zidx
    tl.store(out_ptr + oidx, ret)

@triton.jit
def triton_min_3d_0_2(in_ptr, out_ptr,
    xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr,
    XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr):
    xidx = tl.arange(0,XB)
    yidx = tl.arange(0,YB)
    zidx = tl.arange(0,ZB)
    idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:]
    x = tl.load(in_ptr + idx)
    tmp = tl.min(x, 0)
    ret = tl.min(tmp, 1)
    oidx = yidx
    tl.store(out_ptr + oidx, ret)

@triton.jit
def triton_min_3d_1_2(in_ptr, out_ptr,
    xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr,
    XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr):
    xidx = tl.arange(0,XB)
    yidx = tl.arange(0,YB)
    zidx = tl.arange(0,ZB)
    idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:]
    x = tl.load(in_ptr + idx)
    tmp = tl.min(x, 1)
    ret = tl.min(tmp, 1)
    oidx = xidx
    tl.store(out_ptr + oidx, ret)

def triton_min_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim):
    if no_reduce_dim == 0:
        triton_min_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB)
    elif no_reduce_dim == 1:
        triton_min_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB)
    elif no_reduce_dim == 2:
        triton_min_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB)

@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2])
def test_min_3d(dtype, shape, no_reduce_dim):
    x0 = test_common.generate_tensor(shape, dtype).npu()
    triton_res = torch.empty([shape[no_reduce_dim], ], dtype=eval("torch." + dtype)).npu()
    triton_min_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim)
    torch_res = torch_min_3d(x0, no_reduce_dim)
    test_common.validate_cmp(dtype, triton_res, torch_res)

# >>>>>>> test_min_3d


# <<<<<<< test_min_4d
def torch_min_4d(x0, dim):
    x0 = x0 if x0.device == "cpu" else x0.cpu()
    if x0.dtype in (torch.int8, torch.int16, torch.int32):
        x0 = x0.to(torch.int64)
    return torch.min(x0, dim=dim)[0]


@triton.jit
def min_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr):
    if DIM == 0:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // XB)
        o_idx = tl.arange(0, XB * YB * ZB * MB // XB)
        tl.store(out_ptr + o_idx, ret)
    elif DIM == 1:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // YB)
        o_idx = tl.arange(0, XB * YB * ZB * MB // YB)
        tl.store(out_ptr + o_idx, ret)
    elif DIM == 2:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // ZB)
        o_idx = tl.arange(0, XB * YB * ZB * MB // ZB)
        tl.store(out_ptr + o_idx, ret)
    else:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // MB)
        o_idx = tl.arange(0, XB * YB * ZB * MB // MB)
        tl.store(out_ptr + o_idx, ret)


@triton.jit
def triton_min_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr):
    xidx = tl.arange(0, XB)
    yidx = tl.arange(0, YB)
    zidx = tl.arange(0, ZB)
    midx = tl.arange(0, MB)

    idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[None, None, :, None] * MB + midx[None, None, None, :]

    x = tl.load(in_ptr + idx)

    min_4d(out_ptr, x, XB, YB, ZB, MB, DIM)


def triton_min_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim):
    triton_min_kernel_4d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, dim)


@pytest.mark.shape_4d_5d
@pytest.mark.parametrize('shape', [
    (2, 2, 4, 8)
])
@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('dim', [0])
def test_min_4d(dtype, shape, dim):
    x0 = test_common.generate_tensor(shape, dtype).npu()
    torch_res = torch_min_4d(x0, dim)
    triton_res = torch.empty_like(torch_res).npu()
    triton_min_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim)

    test_common.validate_cmp(dtype, triton_res, torch_res)
# >>>>>>> test_min_4d


# <<<<<<< test_min_5d
def torch_min_5d(x0, dim):
    x0 = x0 if x0.device == "cpu" else x0.cpu()
    if x0.dtype in (torch.int8, torch.int16, torch.int32):
        x0 = x0.to(torch.int64)
    return torch.min(x0, dim=dim)[0]


@triton.jit
def min_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr):
    if DIM == 0:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // XB)
        o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB)
        tl.store(out_ptr + o_idx, ret)
    elif DIM == 1:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // YB)
        o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB)
        tl.store(out_ptr + o_idx, ret)
    elif DIM == 2:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // ZB)
        o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB)
        tl.store(out_ptr + o_idx, ret)
    elif DIM == 3:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // MB)
        o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB)
        tl.store(out_ptr + o_idx, ret)
    else:
        ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // NB)
        o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB)
        tl.store(out_ptr + o_idx, ret)


@triton.jit
def triton_min_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr):
    xidx = tl.arange(0, XB)
    yidx = tl.arange(0, YB)
    zidx = tl.arange(0, ZB)
    midx = tl.arange(0, MB)
    nidx = tl.arange(0, NB)

    idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] 

    x = tl.load(in_ptr + idx)

    min_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM)


def triton_min_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim):
    triton_min_kernel_5d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim)


@pytest.mark.shape_4d_5d
@pytest.mark.parametrize('shape', [
    (2, 2, 2, 4, 8)
])
@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('dim', [0])
def test_min_5d(dtype, shape, dim):
    x0 = test_common.generate_tensor(shape, dtype).npu()
    torch_res = torch_min_5d(x0, dim)
    triton_res = torch.empty_like(torch_res).npu()
    triton_min_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim)

    test_common.validate_cmp(dtype, triton_res, torch_res)
# >>>>>>> test_min_5d