#!/usr/bin/env python3
# coding: utf-8
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""
Cast+Matmul 融合算子 ST 测试脚本。
场景:先 Cast 输入到目标 dtype,再执行 Matmul。
支持 pytest 参数化执行和直接执行两种模式。
"""
import os
import sys

import pytest
import pypto
import torch
import numpy as np

from testcase.matmul_ub2l1_test_case import (
    CastMatmulConfig,
    CAST_RIGHT_MATMUL_TESTS,
    CAST_LEFT_MATMUL_TESTS,
    CAST_BOTH_MATMUL_TESTS,
)


# JIT kernel: optionally cast the A and/or B tile to config.matmul_pto_dtype,
# then matmul the two tiles and store the result into the matching block of
# out_tensor.
#
# Parameters:
#   a_tensor   - left matmul input; tiled along its second axis when
#                config.a_trans is set, otherwise along its first axis.
#   b_tensor   - right matmul input; tiled along its first axis when
#                config.b_trans is set, otherwise along its second axis.
#   out_tensor - output, written block by block as [m-range, n-range].
#   config     - CastMatmulConfig; this kernel reads shape, view_shape,
#                cube_tile_shape, matmul_pto_dtype, out_pto_dtype, a_trans,
#                b_trans, a_cast, b_cast, a_vec_tile_shape, b_vec_tile_shape.
@pypto.frontend.jit(debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0})
def cast_matmul_pto_kernel(
    a_tensor: pypto.Tensor(),
    b_tensor: pypto.Tensor(),
    out_tensor: pypto.Tensor(),
    config: CastMatmulConfig,
):
    # k is unpacked for completeness but not used below: every tile spans the
    # whole reduction axis via the ":" slice.
    m, k, n = config.shape
    m_view, n_view = config.view_shape

    pypto.set_cube_tile_shapes(*config.cube_tile_shape)

    # Ceiling division: number of m_view / n_view sized blocks needed to cover m / n.
    m_loop = (m + m_view - 1) // m_view
    n_loop = (n + n_view - 1) // n_view
    # A scope value greater than 5000 (i.e. 5001 and above) enables the mix
    # scenario, which takes the UB2L1 path.
    pypto.set_pass_options(sg_set_scope=10000)
    for m_idx in pypto.loop(0, m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
        for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L0_nIdx", idx_name="n_idx"):
            # Cast mode is CAST_NONE unless the matmul dtype is int8, in which
            # case CAST_TRUNC is used.
            mode = pypto.CastMode.CAST_NONE
            if config.matmul_pto_dtype == pypto.DT_INT8:
                mode = pypto.CastMode.CAST_TRUNC
            # Select the m-block of A on whichever axis carries m.
            if config.a_trans:
                a_tile = a_tensor[:, m_idx * m_view: m_idx * m_view + m_view]
            else:
                a_tile = a_tensor[m_idx * m_view: m_idx * m_view + m_view, :]

            # The vec tile shape is set right before the cast it applies to.
            if config.a_cast:
                pypto.set_vec_tile_shapes(*config.a_vec_tile_shape)
                a_compute = pypto.cast(a_tile, config.matmul_pto_dtype, mode)
            else:
                a_compute = a_tile

            # Select the n-block of B on whichever axis carries n.
            if config.b_trans:
                b_tile = b_tensor[n_idx * n_view: n_idx * n_view + n_view, :]
            else:
                b_tile = b_tensor[:, n_idx * n_view: n_idx * n_view + n_view]

            if config.b_cast:
                pypto.set_vec_tile_shapes(*config.b_vec_tile_shape)
                b_compute = pypto.cast(b_tile, config.matmul_pto_dtype, mode)
            else:
                b_compute = b_tile

            out_view = pypto.matmul(
                a_compute,
                b_compute,
                out_dtype=config.out_pto_dtype,
                a_trans=config.a_trans,
                b_trans=config.b_trans,
            )

            # Store the result into the (m_idx, n_idx) block of the output.
            out_tensor[
                m_idx * m_view: m_idx * m_view + m_view,
                n_idx * n_view: n_idx * n_view + n_view,
            ] = out_view
    # After the loops, set the scope back to -1 to turn mix off.
    pypto.set_pass_options(sg_set_scope=-1)


def run_cast_matmul_test(case: dict):
    """Run one Cast+Matmul case on the NPU and check it against a CPU golden.

    The NPU index is read from the ``TILE_FWK_DEVICE_ID`` environment variable
    (0 when unset). Inputs are generated on the CPU, the golden result is
    computed with ``torch.matmul`` after converting both inputs to the case's
    matmul dtype, and the kernel output is compared with ``torch.allclose``.

    Args:
        case: Test-case dict. This function reads the keys ``a_input_dtype``,
            ``b_input_dtype``, ``matmul_dtype``, ``out_dtype``, ``id`` and
            ``name``; the whole dict is also handed to
            ``CastMatmulConfig.from_test_case``.

    Raises:
        AssertionError: If the kernel output is not close to the golden result
            under the tolerances returned by ``CastMatmulConfig.get_tolerance``.
    """

    def _random_cpu_tensor(shape, dtype):
        # int8 inputs are integers drawn from [-5, 5]; every other dtype is
        # sampled by torch.rand.
        if dtype == torch.int8:
            return torch.randint(-5, 6, shape, dtype=dtype)
        return torch.rand(shape, dtype=dtype)

    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    npu_device = f"npu:{device_id}"

    config = CastMatmulConfig.from_test_case(case)
    m, k, n = config.shape

    # A transposed operand is stored with its two axes swapped.
    a_shape = [m, k]
    if config.a_trans:
        a_shape = [k, m]
    b_shape = [k, n]
    if config.b_trans:
        b_shape = [n, k]
    c_shape = [m, n]

    a_input_torch_dtype = CastMatmulConfig.get_torch_dtype(case["a_input_dtype"])
    b_input_torch_dtype = CastMatmulConfig.get_torch_dtype(case["b_input_dtype"])
    c_torch_dtype = CastMatmulConfig.get_torch_dtype(case["out_dtype"])

    # A is generated before B so the random stream is consumed in a fixed order.
    a_tensor_cpu = _random_cpu_tensor(a_shape, a_input_torch_dtype)
    b_tensor_cpu = _random_cpu_tensor(b_shape, b_input_torch_dtype)

    matmul_dtype = CastMatmulConfig.get_torch_dtype(case["matmul_dtype"])

    # Golden operands: convert to the matmul dtype, then undo the stored transpose.
    a_cpu = a_tensor_cpu.to(matmul_dtype)
    if config.a_trans:
        a_cpu = a_cpu.T
    b_cpu = b_tensor_cpu.to(matmul_dtype)
    if config.b_trans:
        b_cpu = b_cpu.T

    # The golden matmul runs in int32 for int8 operands, in float32 otherwise.
    accum_dtype = torch.int32 if matmul_dtype == torch.int8 else torch.float32
    golden = torch.matmul(a_cpu.to(accum_dtype), b_cpu.to(accum_dtype)).to(c_torch_dtype)

    # The kernel receives the inputs in their original (pre-cast) dtype.
    a_tensor = a_tensor_cpu.to(npu_device)
    b_tensor = b_tensor_cpu.to(npu_device)
    c_tensor = torch.zeros(c_shape, dtype=c_torch_dtype, device=npu_device)

    cast_matmul_pto_kernel(a_tensor, b_tensor, c_tensor, config)

    atol, rtol = CastMatmulConfig.get_tolerance(case["out_dtype"])

    is_close = torch.allclose(c_tensor.cpu(), golden.cpu(), atol=atol, rtol=rtol)
    assert is_close, f"Test case {case['id']} ({case['name']}) failed"


# Every Cast+Matmul case, concatenated in the order cast-right, cast-left,
# cast-both; used below to parametrize test_cast_matmul.
ALL_CAST_MATMUL_TESTS = (
    CAST_RIGHT_MATMUL_TESTS
    + CAST_LEFT_MATMUL_TESTS
    + CAST_BOTH_MATMUL_TESTS
)


@pytest.mark.parametrize(
    "case",
    [
        pytest.param(test_case, marks=pytest.mark.soc(*test_case["products"]))
        for test_case in ALL_CAST_MATMUL_TESTS
    ],
)
def test_cast_matmul(case: dict):
    """Pytest entry point: run one parametrized Cast+Matmul case.

    Each case carries a ``soc`` mark built from its ``products`` entry.
    """
    run_cast_matmul_test(case)


def run_cast_matmul_demo(run_mode):
    """Run a fixed-size Cast+Matmul demo (M = K = N = 256, 128x128 output blocks).

    The demo casts an fp32 A tile to fp16, multiplies it with an fp16 B tile
    and writes the fp16 result block into the output tensor. The output is
    neither returned nor compared against a reference.

    Args:
        run_mode: ``"npu"`` to run on the NPU or ``"sim"`` to run in simulation
            mode with CPU tensors. In ``"npu"`` mode the device index is read
            from the ``TILE_FWK_DEVICE_ID`` environment variable (0 when
            unset), matching ``run_cast_matmul_test``.

    Raises:
        ValueError: If ``run_mode`` is neither ``"npu"`` nor ``"sim"``.
    """
    m_size, k_size, n_size = 256, 256, 256
    m_view_size, n_view_size = 128, 128

    # Validate the mode first so an invalid value fails before anything is built.
    if run_mode == "npu":
        mode = pypto.RunMode.NPU
    elif run_mode == "sim":
        mode = pypto.RunMode.SIM
    else:
        raise ValueError(f"Invalid run_mode: {run_mode}. Must be 'npu' or 'sim'")

    # Comments rather than a docstring are used inside the JIT function so the
    # traced body contains no extra statement.
    @pypto.frontend.jit(
        debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0},
        runtime_options={"run_mode": mode}
    )
    def cast_matmul_demo_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP16),
        out: pypto.Tensor([], pypto.DT_FP16),
    ):
        pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
        # A scope value above 5000 enables the mix scenario (UB2L1 path).
        pypto.set_pass_options(sg_set_scope=10000)
        # Ceiling division: number of output blocks along M and N.
        m_loop = (m_size + m_view_size - 1) // m_view_size
        n_loop = (n_size + n_view_size - 1) // n_view_size

        for m_idx in pypto.loop(0, m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
            for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L0_nIdx", idx_name="n_idx"):
                a_tile = a[m_idx * m_view_size: m_idx * m_view_size + m_view_size, :]
                # The vec tile shape is set right before the cast it applies to.
                pypto.set_vec_tile_shapes(m_view_size, k_size)
                a_fp16_tile = pypto.cast(a_tile, pypto.DT_FP16)

                b_view = b[:, n_idx * n_view_size: n_idx * n_view_size + n_view_size]
                out_view = pypto.matmul(a_fp16_tile, b_view, pypto.DT_FP16)
                out[
                    m_idx * m_view_size: m_idx * m_view_size + m_view_size,
                    n_idx * n_view_size: n_idx * n_view_size + n_view_size,
                ] = out_view
        # Set the scope back to -1 to turn mix off again, as
        # cast_matmul_pto_kernel does after its loops.
        pypto.set_pass_options(sg_set_scope=-1)

    if run_mode == "npu":
        # Honour the same device-selection variable as the parametrized tests
        # instead of always using device 0.
        device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
        torch.npu.set_device(device_id)
        device = f"npu:{device_id}"
    else:
        device = "cpu"

    a = torch.randn([m_size, k_size], dtype=torch.float32, device=device)
    b = torch.randn([k_size, n_size], dtype=torch.float16, device=device)
    out = torch.empty(m_size, n_size, dtype=torch.float16, device=device)
    cast_matmul_demo_kernel(a, b, out)


# Direct execution (as opposed to pytest collection) runs the demo in NPU mode.
if __name__ == "__main__":
    run_cast_matmul_demo("npu")