# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

from typing import Tuple
import pytest

import asc
from asc.runtime import config
import asc.lib.runtime as rt
import asc.lib.host as host

try:
    import torch
except ModuleNotFoundError:
    pytest.skip("torch is not installed", allow_module_level=True)


@asc.jit
def matmul_kernel(a: asc.GlobalAddress, b: asc.GlobalAddress, c: asc.GlobalAddress, tiling: asc.adv.TCubeTiling,
                  workspace: asc.GlobalAddress, quant_vector: asc.GlobalAddress):
    tiling.share_l1_size = asc.property(asc.TOTAL_L1_SIZE)
    tiling.share_l0c_size = asc.property(asc.TOTAL_L0C_SIZE)
    offset_a, offset_b, offset_c, offset_quant = calc_offsets(tiling)
    a_global = asc.GlobalTensor()
    b_global = asc.GlobalTensor()
    c_global = asc.GlobalTensor()
    quant_vector_global = asc.GlobalTensor()
    a_global.set_global_buffer(a + offset_a)
    b_global.set_global_buffer(b + offset_b)
    c_global.set_global_buffer(c + offset_c)
    quant_vector_global.set_global_buffer(quant_vector + offset_quant)
    matmul = asc.adv.Matmul(
        a=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, a_global.dtype),
        b=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, b_global.dtype),
        c=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, c_global.dtype),
    )
    pipe = asc.TPipe()
    asc.adv.register_matmul(pipe, workspace, matmul, tiling)
    matmul.set_org_shape(tiling.m, tiling.n, tiling.k_a)
    matmul.set_single_shape(tiling.single_core_m, tiling.single_core_n, tiling.single_core_k)
    matmul.set_tensor_a(a_global)
    matmul.set_tensor_b(b_global)
    matmul.set_quant_vector(quant_vector_global)
    matmul.iterate_all(c_global)
    matmul.end()
    asc.pipe_barrier(asc.PipeID.PIPE_ALL)


@asc.jit
def calc_offsets(tiling: asc.adv.TCubeTiling) -> Tuple[int, int, int]:
    block_idx = asc.get_block_idx()
    temp0 = tiling.m.ceildiv(tiling.single_core_m)
    temp1 = tiling.k_a.ceildiv(tiling.single_core_k)
    temp2 = tiling.used_core_num // temp1
    m_index = (block_idx % temp2) % temp0
    n_index = (block_idx % temp2) // temp0
    offset_a = m_index * tiling.k_a * tiling.single_core_m
    offset_b = n_index * tiling.single_core_n
    offset_c = m_index * tiling.n * tiling.single_core_m + n_index * tiling.single_core_n
    offset_quant = n_index * tiling.single_core_n
    return offset_a, offset_b, offset_c, offset_quant


def matmul_launch(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, workspace: torch.Tensor, 
                    tiling: asc.adv.TCubeTiling, quant_vector: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0], "Matrices must be compatible for a multiplication"
    matmul_kernel[tiling.used_core_num, rt.current_stream()](a, b, c, tiling, workspace, quant_vector)
    return c


def generate_tiling(m, n, k):
    matmul_tiling = host.MultiCoreMatmulTiling(host.get_ascendc_platform())

    matmul_tiling.set_a_type(host.TPosition.GM, host.CubeFormat.ND, host.DataType.DT_INT8, False)
    matmul_tiling.set_b_type(host.TPosition.GM, host.CubeFormat.ND, host.DataType.DT_INT8, False)
    matmul_tiling.set_c_type(host.TPosition.VECCALC, host.CubeFormat.ND, host.DataType.DT_FLOAT16)

    matmul_tiling.set_dim(16)
    matmul_tiling.set_org_shape(m, n, k)
    matmul_tiling.set_shape(m, n, k)
    matmul_tiling.set_traverse(host.MatrixTraverse.FIRSTM)
    matmul_tiling.set_buffer_space(-1, -1, -1)

    tiling = asc.adv.TCubeTiling()
    matmul_tiling.get_tiling(tiling)
    return tiling


def gen_golden(torch_output: torch.Tensor, quant_vector: torch.Tensor):
    m, n = torch_output.shape
    golden = torch_output.to(torch.float16)
    quant_vector = quant_vector.view(torch.uint32)
    for index, data in enumerate(quant_vector):
        # 1 sign bit, 8 exponent bits and 10 mantissa bits
        quant_vector[index] = torch.bitwise_and(data, 0xFFFFE000)
    quant_vector = quant_vector.view(torch.float32)
    for i in range(m):
        golden[i, :] = golden[i, :] * quant_vector
    quant_vector_uint64 = torch.zeros(quant_vector.shape, dtype=torch.uint64)
    quant_vector_uint64 = quant_vector.view(torch.uint32).to(torch.uint64)
    return golden, quant_vector_uint64


param_list = [
    [torch.int8, (256, 256, 256)],
    [torch.int8, (256, 128, 64)],
    [torch.int8, (64, 128, 256)],
]


BACKENDS = [
    # config.Backend.Model,
    config.Backend.NPU,
]


@pytest.mark.parametrize("dtype, size", param_list)
@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.skip(reason="Tensor self not implemented for DT_INT32")
def test_matmul_quant(dtype, size, backend: config.Backend):
    config.set_platform(backend)
    device = "npu" if config.Backend(backend) == config.Backend.NPU else "cpu"
    m, k, n = size
    a = torch.randint(-5, 5, (m, k), dtype=dtype, device=device)
    b = torch.randint(-5, 5, (k, n), dtype=dtype, device=device)
    c = torch.zeros((m, n), dtype=torch.float16, device=device) # int8 in float16 out
    workspace = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=device)
    quant_vector = (torch.rand((1, n), dtype=torch.float32, device=device) - .5) * 10
    torch_output = a.to(torch.int32) @ b.to(torch.int32)
    golden, quant_vector = gen_golden(torch_output, quant_vector)
    tiling = generate_tiling(m, n, k)
    output = matmul_launch(a, b, c, workspace, tiling, quant_vector)
    assert torch.allclose(output, golden, atol=1e-3)