# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import logging
import pytest
import torch
import torch_npu
import triton
import triton.language as tl

import acc_util
import test_common
from test_common import TestUtils, avoid_not_support, get_dtype_size


@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M: tl.constexpr,
        N: tl.constexpr,
        K: tl.constexpr,
        acc_dtype: tl.constexpr,
        stride_am: tl.constexpr,
        stride_ak: tl.constexpr,
        stride_bk: tl.constexpr,
        stride_bn: tl.constexpr,
        stride_cm: tl.constexpr,
        stride_cn: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M))
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N))
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = accumulator.to(c_ptr.dtype.element_ty)

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

@avoid_not_support('matmul')
@pytest.mark.parametrize('shape', TestUtils.test_shape2d)
@pytest.mark.parametrize('dtype', TestUtils.dtype_list)
def test_matmul(shape, dtype):
    M, N, K = shape[0], shape[0], shape[1]
    # 32byte/Dtype_bytes
    kalign = 32 // get_dtype_size(dtype)  
    BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32)
    a = test_common.generate_tensor((M, K), dtype)
    b = test_common.generate_tensor((K, N), dtype)

    if dtype == "int8":
        triton_res = torch.zeros((M, N), dtype=torch.int32).npu()
        accumulator_type = tl.int32
    else:
        triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu()
        accumulator_type = tl.float32
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
    
    matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type,
                        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                        triton_res.stride(0), triton_res.stride(1),
                        BLOCK_M, BLOCK_N, BLOCK_K)

    a_gold = a.to(torch.float32)
    b_gold = b.to(torch.float32)
    cpu_res = torch.mm(a_gold, b_gold)

    if dtype == "int8":
        # torch_npu do not support int8 matmul
        a_npu = a.npu().to(torch.float32)
        b_npu = b.npu().to(torch.float32)
        torch_res = torch.mm(a_npu, b_npu)
        triton_res = triton_res.to(torch.float32)
    else:
        a_npu = a.npu()
        b_npu = b.npu()
        torch_res = torch.mm(a_npu, b_npu)

    try:
        print("starting compare of cpu vs triton:")
        acc_util.assert_close(cpu_res, triton_res)
    except Exception as e:
        print(e)
        print("starting compare of cpu vs triton vs torch_npu:")
        acc_util.benchmark_compare_close(cpu_res, triton_res, torch_res)
    print("PASSED")


@avoid_not_support('matmul')
@pytest.mark.parametrize('batch', TestUtils.batch)
@pytest.mark.parametrize('shape', TestUtils.test_shape2d)
@pytest.mark.parametrize('dtype', TestUtils.dtype_list)
def test_batch_matmul(shape, dtype, batch):
    M, N, K = shape[0], shape[0], shape[1]
    # 32byte/Dtype_bytes
    kalign = 32 // get_dtype_size(dtype)  
    BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32)

    aa = test_common.generate_tensor((batch, M, K), dtype)
    bb = test_common.generate_tensor((batch, K, N), dtype)

    if dtype == "int8":
        final_triton_res = torch.zeros((batch, M, N), dtype=torch.int32).npu()
        accumulator_type = tl.int32
    else:
        final_triton_res = torch.zeros((batch, M, N), dtype=eval('torch.' + dtype)).npu()
        accumulator_type = tl.float32
    
    for i in range(0, batch):
        if dtype == "int8":
            triton_res = torch.zeros((M, N), dtype=torch.int32).npu()
        else:
            triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu()
        grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
        a = aa[i]
        b = bb[i]
        matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type,
                        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                        triton_res.stride(0), triton_res.stride(1),
                        BLOCK_M, BLOCK_N, BLOCK_K)
        final_triton_res[i] = triton_res

    a_gold = aa.to(torch.float32)
    b_gold = bb.to(torch.float32)
    cpu_res = torch.bmm(a_gold, b_gold)

    if dtype == "int8":
        a_npu = aa.npu().to(torch.float32)
        b_npu = bb.npu().to(torch.float32)
        final_triton_res = final_triton_res.to(torch.float32)
    else:
        a_npu = aa.npu()
        b_npu = bb.npu()
    torch_res = torch.bmm(a_npu, b_npu)
    
    try:
        print("starting compare of cpu vs triton:")
        acc_util.assert_close(cpu_res, final_triton_res)
    except Exception as e:
        print(e)
        print("starting compare of cpu vs triton vs torch_npu:")
        acc_util.benchmark_compare_close(cpu_res, final_triton_res, torch_res)
    print("PASSED")


if __name__ == "__main__":
    test_matmul((16, 32), 'float32')
    test_matmul((16, 32), 'int8')
    test_batch_matmul(2, (16, 32), 'float32')
    test_batch_matmul(2, (16, 32), 'int8')