# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


import triton
import triton.language as tl
import torch
import pytest
import test_common
from test_common import TestUtils


@triton.jit
def fn_broadcast_1d(output_ptr, x_ptr,  XS: tl.constexpr, YS: tl.constexpr):
    xidx = tl.arange(0, XS)[None, :]
    base = tl.load(x_ptr + xidx)
    out = base.broadcast_to((YS, XS))
    oidx = tl.arange(0, YS)[:, None] * XS + tl.arange(0, XS)[None, :]
    tl.store(output_ptr + oidx, out)


@pytest.mark.parametrize('shape', TestUtils.test_shape1d)
@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
def test_npu_1d(shape, dtype):
    XS = shape[0]
    YS = 4

    x = test_common.generate_tensor((XS, ), dtype=dtype).npu()
    std = torch.broadcast_to(x, (YS, XS))
    output = test_common.generate_tensor((YS, XS), dtype=dtype).npu()
    fn_broadcast_1d[1, 1, 1](output, x, XS, YS)
    test_common.validate_cmp(dtype, std, output)


@triton.jit
def fn_broadcast_2d(output_ptr, x_ptr, NUMEL: tl.constexpr, XS: tl.constexpr, YS: tl.constexpr, ZS: tl.constexpr):
    zoffset = tl.program_id(0) * ZS
    zidx = tl.arange(0, ZS)[None, :]
    base = tl.load(x_ptr + zoffset + zidx)
    out = base.broadcast_to((YS, ZS))
    oidx = zoffset * YS + tl.arange(0, YS)[:, None] * ZS + zidx
    tl.store(output_ptr + oidx, out)


@pytest.mark.parametrize('shape', TestUtils.test_shape2d)
@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
def test_npu_2d(shape, dtype):
    XS = shape[0]
    ZS = shape[1]
    YS = 4
    NUMEL = XS * ZS

    x = test_common.generate_tensor((XS, 1, ZS), dtype=dtype).npu()  # randn not support int type
    std = torch.broadcast_to(x, (XS, YS, ZS))
    output = test_common.generate_tensor((XS, YS, ZS), dtype=dtype).npu()
    fn_broadcast_2d[XS, 1, 1](output, x, NUMEL, XS, YS, ZS)
    test_common.validate_cmp(dtype, std, output)


@triton.jit
def triton_broadcast_to_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    lblk_idx = tl.arange(0, L)
    mblk_idx = tl.arange(0, M)
    nblk_idx = tl.arange(0, N)
    idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    x = tl.load(in_ptr0 + idx)
    x1 = tl.load(out_ptr0 + odx)
    ret = tl.broadcast(x, x1)
    tl.store(out_ptr0 + odx, ret)


@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
def test_broadcast_to_dim0(shape, dtype):
    L, M, N = shape
    x0 = test_common.generate_tensor(shape=(1, M, N), dtype=dtype).npu()
    ans = x0.repeat(L, 1, 1)
    output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu()
    triton_broadcast_to_dim0[1, 1, 1](x0, output, L, M, N)
    test_common.validate_cmp(dtype, output, ans)


@triton.jit
def triton_broadcast_to_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    lblk_idx = tl.arange(0, L)
    mblk_idx = tl.arange(0, M)
    nblk_idx = tl.arange(0, N)
    idx = lblk_idx[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :]
    odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    x = tl.load(in_ptr0 + idx)
    x1 = tl.load(out_ptr0 + odx)
    ret = tl.broadcast(x, x1)
    tl.store(out_ptr0 + odx, ret)


@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
def test_broadcast_to_dim1(shape, dtype):
    L, M, N = shape
    x0 = test_common.generate_tensor(shape=(L, 1, N), dtype=dtype).npu()
    ans = x0.repeat(1, M, 1)
    output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu()
    triton_broadcast_to_dim1[1, 1, 1](x0, output, L, M, N)
    test_common.validate_cmp(dtype, output, ans)


@triton.jit
def triton_broadcast_to_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    lblk_idx = tl.arange(0, L)
    mblk_idx = tl.arange(0, M)
    nblk_idx = tl.arange(0, N)
    idx = lblk_idx[:, None, None] * 1 * M + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :]
    odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    x = tl.load(in_ptr0 + idx)
    x1 = tl.load(out_ptr0 + odx)
    ret = tl.broadcast(x, x1)
    tl.store(out_ptr0 + odx, ret)


@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
def test_broadcast_to_dim2(shape, dtype):
    L, M, N = shape
    x0 = test_common.generate_tensor(shape=(L, M, 1), dtype=dtype).npu()
    ans = x0.repeat(1, 1, N)
    output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu()
    triton_broadcast_to_dim2[1, 1, 1](x0, output, L, M, N)
    test_common.validate_cmp(dtype, output, ans)


@triton.jit
def triton_broadcast_to_dim01(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    lblk_idx = tl.arange(0, L)
    mblk_idx = tl.arange(0, M)
    nblk_idx = tl.arange(0, N)
    idx = tl.arange(0, 1)[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :]
    odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    x = tl.load(in_ptr0 + idx)
    x1 = tl.load(out_ptr0 + odx)
    ret = tl.broadcast(x, x1)
    tl.store(out_ptr0 + odx, ret)


@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
def test_broadcast_to_dim01(shape, dtype):
    L, M, N = shape
    x0 = test_common.generate_tensor(shape=(1, 1, N), dtype=dtype).npu()
    ans = x0.repeat(L, M, 1)
    output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu()
    triton_broadcast_to_dim01[1, 1, 1](x0, output, L, M, N)
    test_common.validate_cmp(dtype, output, ans)


@triton.jit
def triton_broadcast_to_dim02(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    lblk_idx = tl.arange(0, L)
    mblk_idx = tl.arange(0, M)
    nblk_idx = tl.arange(0, N)
    idx = tl.arange(0, 1)[:, None, None] * M * 1 + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :]
    odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    x = tl.load(in_ptr0 + idx)
    x1 = tl.load(out_ptr0 + odx)
    ret = tl.broadcast(x, x1)
    tl.store(out_ptr0 + odx, ret)


@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
def test_broadcast_to_dim02(shape, dtype):
    L, M, N = shape
    x0 = test_common.generate_tensor(shape=(1, M, 1), dtype=dtype).npu()
    ans = x0.repeat(L, 1, N)
    output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu()
    triton_broadcast_to_dim02[1, 1, 1](x0, output, L, M, N)
    test_common.validate_cmp(dtype, output, ans)


@triton.jit
def triton_broadcast_to_dim12(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    lblk_idx = tl.arange(0, L)
    mblk_idx = tl.arange(0, M)
    nblk_idx = tl.arange(0, N)
    idx = lblk_idx[:, None, None] * 1 * 1 + tl.arange(0, 1)[None, :, None] * 1 + tl.arange(0, 1)[None, None, :]
    odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    x = tl.load(in_ptr0 + idx)
    x1 = tl.load(out_ptr0 + odx)
    ret = tl.broadcast(x, x1)
    tl.store(out_ptr0 + odx, ret)


@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
def test_broadcast_to_dim12(shape, dtype):
    L, M, N = shape
    x0 = test_common.generate_tensor(shape=(L, 1, 1), dtype=dtype).npu()
    ans = x0.repeat(1, M, N)
    output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu()
    triton_broadcast_to_dim12[1, 1, 1](x0, output, L, M, N)
    test_common.validate_cmp(dtype, output, ans)


@triton.jit
def fn_broadcast_multi_d(to_ptr, from_ptr, F_L: tl.constexpr, F_M: tl.constexpr, F_N: tl.constexpr, F_X: tl.constexpr, F_Y: tl.constexpr, T_L: tl.constexpr, T_M: tl.constexpr, T_N: tl.constexpr, T_X: tl.constexpr, T_Y: tl.constexpr):
    from_offsets = tl.arange(0, F_L)
    if F_M is not None:
        from_offsets = from_offsets[:, None] * F_M + tl.arange(0, F_M)[None, :]
    if F_N is not None:
        from_offsets = from_offsets[:, :, None] * F_N + tl.arange(0, F_N)[None, None, :]
    if F_X is not None:
        from_offsets = from_offsets[:, :, :, None] * F_X + tl.arange(0, F_X)[None, None, None, :]
    if F_Y is not None:
        from_offsets = from_offsets[:, :, :, :, None] * F_Y + tl.arange(0, F_Y)[None, None, None, None, :]

    to_offsets = tl.arange(0, T_L)
    if T_M is not None:
        to_offsets = to_offsets[:, None] * T_M + tl.arange(0, T_M)[None, :]
    if T_N is not None:
        to_offsets = to_offsets[:, :, None] * T_N + tl.arange(0, T_N)[None, None, :]
    if T_X is not None:
        to_offsets = to_offsets[:, :, :, None] * T_X + tl.arange(0, T_X)[None, None, None, :]
    if T_Y is not None:
        to_offsets = to_offsets[:, :, :, :, None] * T_Y + tl.arange(0, T_Y)[None, None, None, None, :]

    from_data = tl.load(from_ptr + from_offsets)
    to_data = tl.load(to_ptr + to_offsets)
    ret_data = tl.broadcast(from_data, to_data)

    tl.store(to_ptr + to_offsets, ret_data)


@pytest.mark.shape_4d_5d
@pytest.mark.parametrize('shapes', [
    [(1, 64, 16, 1), (2, 64, 16, 2)],
    [(8, 1, 1, 2), (8, 8, 4, 2)],
])
@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"])
def test_broadcast_to_4d(shapes, dtype):
    from_shape, to_shape = shapes
    dtype = eval(f"torch.{dtype}")

    x = torch.randint(0, 8, from_shape, dtype=dtype).npu()
    y = torch.randint(0, 8, to_shape, dtype=dtype).npu()
    expected = x.expand(to_shape)

    grid = (1, )
    triton_from_shape = [*from_shape]
    triton_to_shape = [*to_shape]
    while len(triton_from_shape) < 5:
        triton_from_shape.append(None)
        triton_to_shape.append(None)
    fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape)
    assert(torch.equal(y, expected))


@pytest.mark.shape_4d_5d
@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"])
@pytest.mark.parametrize('shapes', [
    [(1, 4, 2, 1, 4), (2, 4, 2, 8, 4)],
    [(3, 1, 2, 1, 4), (3, 4, 2, 8, 4)],
])
def test_broadcast_to_5d(shapes, dtype):
    from_shape, to_shape = shapes
    dtype = eval(f"torch.{dtype}")

    x = torch.randint(0, 8, from_shape, dtype=dtype).npu()
    y = torch.randint(0, 8, to_shape, dtype=dtype).npu()
    expected = x.expand(to_shape)

    grid = (1, )
    triton_from_shape = [*from_shape]
    triton_to_shape = [*to_shape]
    while len(triton_from_shape) < 5:
        triton_from_shape.append(None)
        triton_to_shape.append(None)
    fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape)
    assert(torch.equal(y, expected))


@triton.jit
def fn_broadcast(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    lblk_idx = tl.arange(0, L)
    mblk_idx = tl.arange(0, M)
    nblk_idx = tl.arange(0, N)
    idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :]
    x = tl.load(in_ptr0 + idx)
    x1 = tl.load(out_ptr0 + odx)
    ret = tl.broadcast(x, x1)
    tl.store(out_ptr0 + odx, ret)

XS : tl.constexpr = 2
YS : tl.constexpr = 4
ZS : tl.constexpr = 8
@pytest.mark.parametrize('dtype', ["uint8","int8","int16","int32", "int64", "float16", "float32", "bfloat16","bool"])
def test_broadcast_alltype(dtype):
    input = test_common.generate_tensor((1, YS, ZS), dtype).npu()
    ans = input.repeat(XS, 1, 1)
    output = torch.zeros((XS, YS, ZS), dtype=eval('torch.' + dtype)).npu()
    fn_broadcast[1, 1, 1](input, output, XS, YS, ZS)
    test_common.validate_cmp(dtype, ans, output)