import importlib
import os
import urllib.request
from pathlib import Path

import pytest

import test_common


def discover_kernels():
    """Collect kernel modules located in sub-directories of this file's directory.

    Every ``*.py`` file found recursively below the directory that contains
    this file is considered, except files sitting directly in that directory
    and ``__init__.py`` files.

    Returns:
        A list of ``(module_path, kernel_name)`` tuples. ``module_path`` is
        the dotted path of the file relative to this directory (without the
        ``.py`` suffix) and ``kernel_name`` is the file stem. The list is
        sorted by ``kernel_name`` and then by ``module_path``.
    """
    kernels_root_path = Path(__file__).parent
    kernels = []
    for p in kernels_root_path.rglob("*.py"):
        if not p.is_file() or p.name == "__init__.py":
            continue
        rel = p.relative_to(kernels_root_path)
        # A single path component means the file lives next to this module
        # rather than inside a kernel sub-directory.
        if len(rel.parts) == 1:
            continue
        module_path = ".".join(rel.with_suffix("").parts)
        kernels.append((module_path, p.stem))
    # rglob yields entries in filesystem order; the module path is used as a
    # tie-breaker so kernels sharing a stem always come out in the same order.
    return sorted(kernels, key=lambda item: (item[1], item[0]))


# Evaluated once at import time: the (module_path, kernel_name) pairs that
# parametrize test_triton_kernel below.
KERNEL_ITEMS = discover_kernels()


@pytest.mark.parametrize("module_path, kernel_name", KERNEL_ITEMS)
def test_triton_kernel(module_path, kernel_name, pytestconfig):
    """Run one discovered kernel through ``test_common.run_and_compare_ptfile``.

    The ``<kernel_name>.pt`` file is looked up next to this file; when it is
    missing it is downloaded from the artifact bucket and removed again once
    the test finishes, whatever the outcome. The kernel object is the module
    attribute named ``kernel_name`` or, failing that, the first attribute
    (in ``dir()`` order) whose name ends with ``_kernel``.

    Args:
        module_path: Dotted import path of the module holding the kernel.
        kernel_name: Stem of the kernel file; also the ``.pt`` file name.
        pytestconfig: pytest's config fixture, used to read ``--kernel``.
    """
    selected = pytestconfig.getoption("kernel")
    # NOTE(review): if --kernel is registered as a plain string rather than a
    # list, `in` performs substring matching here -- confirm in conftest.
    if selected and kernel_name not in selected:
        pytest.skip(f"skip {kernel_name} due to --kernel filter")

    base_url = "https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com"
    # The remote folder is named after the first component of the module path.
    top_package = module_path.split(".")[0]
    pt_url = f"{base_url}/test/kernels/{top_package}_pt/{kernel_name}.pt"
    local_pt = Path(__file__).parent / f"{kernel_name}.pt"

    downloaded = False
    try:
        if not local_pt.exists():
            try:
                urllib.request.urlretrieve(pt_url, local_pt)
            except Exception as e:
                # An interrupted transfer can leave a truncated file behind;
                # drop it so a later run does not mistake it for a valid one.
                if local_pt.exists():
                    local_pt.unlink()
                pytest.fail(f"Failed to download the {kernel_name}.pt file. Please check whether the {kernel_name}.pt file has been uploaded to the OBS bucket: {e}")
            downloaded = True

        try:
            mod = importlib.import_module(module_path)
        except Exception as e:
            pytest.fail(f"import {module_path} failed: {e}")

        if hasattr(mod, kernel_name):
            kernel_attr = kernel_name
        else:
            candidates = [a for a in dir(mod) if a.endswith("_kernel")]
            kernel_attr = candidates[0] if candidates else None

        if not kernel_attr:
            pytest.fail(f"No kernel callable found in {module_path}")

        kernel_callable = getattr(mod, kernel_attr)

        def runner(input_data, grid):
            # Launch the kernel with the given grid, passing inputs by keyword.
            kernel_callable[grid](**input_data)

        test_common.run_and_compare_ptfile(str(local_pt), runner, device_type='npu')
    finally:
        # Only remove files this test fetched itself; a pre-existing local
        # .pt file is left untouched.
        if downloaded and local_pt.exists():
            local_pt.unlink()