# Copyright (c) 2025 Huawei Technologies Co., Ltd.

# This program is free software, you can redistribute it and/or modify it under the terms and conditions of

# CANN Open Software License Agreement Version 2.0 (the "License").

# Please refer to the License for details. You may not use this file except in compliance with the License.

# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

# See LICENSE in the root of the software repository for the full text of the License.



import pytest



import asc

import asc.runtime.config as config

import asc.lib.runtime as rt

from asc.language.core.dtype import KnownTypes as KT

from asc.language.core.struct import Field, Struct, StructField



try:

    import torch

except ModuleNotFoundError:

    pytest.skip("torch is not installed", allow_module_level=True)





# Block-level tiling parameters, passed to the kernel as part of KernelConfig.
class BlockConfig(Struct):

    # int32 field, default 0. vadd_kernel multiplies it by the block index to
    # obtain the offset at which that block's data starts.
    block_length = Field(dtype=KT.int32, default=0)





# Tile-level tiling parameters, passed to the kernel as part of KernelConfig.
class TileConfig(Struct):

    # int32 field, default 0. Used by the kernel as the element count of every
    # data_copy / add call and, times the element size, as the queue buffer length.
    tile_length = Field(dtype=KT.int32, default=0)

    # int32 field, default 0. Number of iterations of the kernel's tile loop.
    tile_num = Field(dtype=KT.int32, default=0)





# Top-level kernel argument struct. It nests the block- and tile-level structs,
# which the kernel reads as kernel_config.block.* and kernel_config.tile.*.
class KernelConfig(Struct):

    # Nested BlockConfig member.
    block = StructField(BlockConfig)

    # Nested TileConfig member.
    tile = StructField(TileConfig)





# Kernel entry point: adds x and y into z, one tile at a time, for this block.
#
# x, y, z       -- global addresses of the two inputs and of the output.
# buffer_num    -- annotated asc.ConstExpr[int]; forwarded as `num` to every
#                  pipe.init_buffer call below.
# kernel_config -- nested tiling struct; the kernel reads block.block_length,
#                  tile.tile_length and tile.tile_num from it.
@asc.jit

def vadd_kernel(x: asc.GlobalAddress, y: asc.GlobalAddress, z: asc.GlobalAddress, buffer_num: asc.ConstExpr[int],

                kernel_config: KernelConfig):

    # Start of this block's data: block index times the per-block length
    # (presumably counted in elements -- confirm against GlobalAddress arithmetic).
    offset = asc.get_block_idx() * kernel_config.block.block_length

    x_gm = asc.GlobalTensor()

    y_gm = asc.GlobalTensor()

    z_gm = asc.GlobalTensor()

    # Bind each global tensor to its address shifted by this block's offset.
    x_gm.set_global_buffer(x + offset)

    y_gm.set_global_buffer(y + offset)

    z_gm.set_global_buffer(z + offset)

    pipe = asc.TPipe()

    # One queue per operand: VECIN for the two inputs, VECOUT for the output.
    in_queue_x = asc.TQue(asc.TPosition.VECIN, 1)

    in_queue_y = asc.TQue(asc.TPosition.VECIN, 1)

    out_queue_z = asc.TQue(asc.TPosition.VECOUT, 1)

    # Each queue buffer is sized for one tile: tile_length times the operand's
    # dtype.sizeof().
    pipe.init_buffer(que=in_queue_x, num=buffer_num, len=kernel_config.tile.tile_length * x.dtype.sizeof())

    pipe.init_buffer(que=in_queue_y, num=buffer_num, len=kernel_config.tile.tile_length * y.dtype.sizeof())

    pipe.init_buffer(que=out_queue_z, num=buffer_num, len=kernel_config.tile.tile_length * z.dtype.sizeof())

    # Process tile_num tiles in sequence: load tile i, add, store tile i.
    for i in range(kernel_config.tile.tile_num):

        copy_in(i, x_gm, y_gm, in_queue_x, in_queue_y, kernel_config.tile.tile_length)

        compute(z_gm, in_queue_x, in_queue_y, out_queue_z, kernel_config.tile.tile_length)

        copy_out(i, z_gm, out_queue_z, kernel_config.tile.tile_length)





# Load stage: copies tile i of x and y from global memory into freshly
# allocated local tensors and enqueues them on the two input queues.
#
# i           -- index of the tile within this block.
# x_gm, y_gm  -- sources for the two inputs.
# in_queue_x, in_queue_y -- queues that supply the local tensors and receive them.
# tile_length -- number of elements copied per operand.
#
# NOTE(review): x_gm / y_gm are annotated asc.GlobalAddress, but vadd_kernel
# passes asc.GlobalTensor objects here -- confirm the annotation is intended.
@asc.jit

def copy_in(i: int, x_gm: asc.GlobalAddress, y_gm: asc.GlobalAddress, in_queue_x: asc.TQue, in_queue_y: asc.TQue,

            tile_length: int):

    x_local = in_queue_x.alloc_tensor(x_gm.dtype)

    y_local = in_queue_y.alloc_tensor(y_gm.dtype)

    # Tile i starts i * tile_length into the block's data; count gives the
    # number of elements copied.
    asc.data_copy(x_local, x_gm[i * tile_length:], count=tile_length)

    asc.data_copy(y_local, y_gm[i * tile_length:], count=tile_length)

    # Hand the filled tensors to the compute stage.
    in_queue_x.enque(x_local)

    in_queue_y.enque(y_local)





# Compute stage: dequeues one tile of x and y, adds them into a newly allocated
# output tensor, enqueues the result and frees the two input tensors.
#
# z_gm        -- output global tensor; only its dtype is read here.
# in_queue_x, in_queue_y -- queues filled by copy_in.
# out_queue_z -- queue that supplies the output tensor and receives the result.
# tile_length -- number of elements passed to asc.add.
@asc.jit

def compute(z_gm: asc.GlobalTensor, in_queue_x: asc.TQue, in_queue_y: asc.TQue, out_queue_z: asc.TQue,

            tile_length: int):

    # z_gm is passed only to obtain a dtype. The x and y tensors are dequeued
    # with z's dtype, so this assumes all three operands share one dtype
    # (vadd_launch asserts x.dtype == y.dtype and allocates z with x's dtype).

    x_local = in_queue_x.deque(z_gm.dtype)

    y_local = in_queue_y.deque(z_gm.dtype)

    z_local = out_queue_z.alloc_tensor(z_gm.dtype)

    asc.add(z_local, x_local, y_local, tile_length)

    # Publish the result for copy_out, then free the input tensors.
    out_queue_z.enque(z_local)

    in_queue_x.free_tensor(x_local)

    in_queue_y.free_tensor(y_local)





# Store stage: dequeues one result tile and copies it to tile i of the output
# in global memory, then frees the local tensor.
#
# i           -- index of the tile within this block.
# z_gm        -- destination global tensor.
# out_queue_z -- queue filled by compute.
# tile_length -- number of elements copied.
@asc.jit

def copy_out(i: int, z_gm: asc.GlobalTensor, out_queue_z: asc.TQue, tile_length: int):

    z_local = out_queue_z.deque(z_gm.dtype)

    # Tile i starts i * tile_length into the block's output data.
    asc.data_copy(z_gm[i * tile_length:], z_local, count=tile_length)

    out_queue_z.free_tensor(z_local)





def vadd_launch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute ``x + y`` element-wise by launching ``vadd_kernel`` on 16 cores.

    The kernel gives block ``k`` the data starting at ``k * block_length`` and
    always moves whole tiles of ``tile_length`` elements, ``tile_num`` times.
    With ``block_length = ceil(n / cores)`` that range runs past the end of the
    tensors whenever ``n`` is not a multiple of ``cores * tile_length`` (for
    ``n == 1000`` the last block would cover elements 945..1456). To keep every
    access inside the allocations, ``block_length`` is rounded up to a whole
    number of tiles and the operands are staged in zero-padded flat buffers of
    ``cores * block_length`` elements; the padding is sliced off the result.

    Args:
        x: First operand.
        y: Second operand; must have the same shape and dtype as ``x``.

    Returns:
        A tensor with the shape and dtype of ``x`` holding ``x + y``.

    Raises:
        AssertionError: If the shapes or dtypes of ``x`` and ``y`` differ.
    """
    assert x.shape == y.shape
    assert x.dtype == y.dtype

    total_length = x.numel()
    if total_length == 0:
        # Nothing to add; skip the launch instead of passing empty buffers.
        return torch.zeros_like(x)

    use_core_num = 16
    tile_length = 512
    buffer_num = 1

    # Tiles needed per core to cover its ceil(total / cores) share of the data.
    elements_per_core = (total_length + use_core_num - 1) // use_core_num
    tile_num = (elements_per_core + tile_length - 1) // tile_length
    # Whole tiles per block: blocks neither overlap nor run past the buffers.
    block_length = tile_num * tile_length
    padded_length = block_length * use_core_num

    def stage(src: torch.Tensor) -> torch.Tensor:
        # Flat, zero-padded copy of src covering the full block/tile grid.
        staged = torch.zeros(padded_length, dtype=src.dtype, device=src.device)
        staged[:total_length] = src.reshape(-1)
        return staged

    x_buf = stage(x)
    y_buf = stage(y)
    z_buf = torch.zeros(padded_length, dtype=x.dtype, device=x.device)

    # The structs are filled through both constructor keywords and attribute
    # assignment, as in the original code (presumably so both paths of the
    # Struct API stay exercised).
    block_config = BlockConfig(block_length=block_length)
    tile_config = TileConfig(tile_length=tile_length, tile_num=0)
    tile_config.tile_num = tile_num
    kernel_config = KernelConfig(tile=tile_config)
    kernel_config.block = block_config

    vadd_kernel[use_core_num, rt.current_stream()](x_buf, y_buf, z_buf, buffer_num, kernel_config)

    # Drop the padding and restore the caller's shape.
    return z_buf[:total_length].reshape(x.shape)





# [dtype, shape] cases for test_vadd_tiling. The shapes include element counts
# that are not multiples of the 512-element tile used by vadd_launch
# (1, 1000, 9999) as well as one 2-D shape.
param_list = [

    [torch.float32, (1000,)],

    [torch.float32, (1,)],

    [torch.float32, (9999,)],

    [torch.float16, (2048,)],

    [torch.int32, (8192,)],

    [torch.int16, (8192,)],

    [torch.int32, (153, 834)],

]





# Backends that test_vadd_tiling is parametrized over.
BACKENDS = [

    # The Model backend is currently disabled; uncomment to run it as well.
    # config.Backend.Model,

    config.Backend.NPU,

]





@pytest.mark.parametrize("dtype, size", param_list)
@pytest.mark.parametrize("backend", BACKENDS)
def test_vadd_tiling(dtype, size, backend: config.Backend):
    """Check that ``vadd_launch`` matches ``x + y`` for one dtype/shape/backend case.

    The platform is switched to ``backend`` first. Operands live on ``"npu"``
    for the NPU backend and on ``"cpu"`` otherwise.
    """
    config.set_platform(backend)
    runs_on_npu = config.Backend(backend) == config.Backend.NPU
    device = "npu" if runs_on_npu else "cpu"

    def random_operand():
        # Floating dtypes use randn; every other dtype uses randint(-100, 99).
        if dtype in {torch.float16, torch.float32}:
            return torch.randn(size, dtype=dtype, device=device)
        return torch.randint(-100, 99, size, dtype=dtype, device=device)

    x = random_operand()
    y = random_operand()

    # The reference sum is computed after the kernel launch, as before.
    actual = vadd_launch(x, y)
    expected = x + y
    assert torch.allclose(actual, expected)