"""Tests for grid_generator/runner.py — pure functions and CLI path."""

import argparse
import tempfile
from pathlib import Path
from unittest import mock

import pytest

from tools.perf_data_collection.grid_generator.runner import (
    iter_csv_files,
    load_csv_files,
    process_theory_csv,
    run_theory_mode,
)


class TestIterCsvFiles:
    """Checks for ``iter_csv_files`` run against throwaway directories."""

    def test_sorts_files(self):
        """Expect CSV files to come back ordered by name.

        The files are created out of order so the assertion cannot pass
        merely because the filesystem lists entries in creation order.
        """
        with tempfile.TemporaryDirectory() as td:
            datadir = Path(td)
            for name in ("RmsNorm.csv", "MatMulV2.csv", "PadV3.csv"):
                (datadir / name).write_text("x")
            names = [p.name for p in iter_csv_files(datadir)]
            # Exact comparison checks both the ordering and that nothing
            # unexpected is yielded.
            assert names == ["MatMulV2.csv", "PadV3.csv", "RmsNorm.csv"]

    def test_excludes_tmp_files(self):
        """Expect ``*.tmp.csv`` files to be left out of the results."""
        with tempfile.TemporaryDirectory() as td:
            datadir = Path(td)
            (datadir / "MatMulV2.csv").write_text("a")
            (datadir / "MatMulV2.tmp.csv").write_text("b")
            names = [p.name for p in iter_csv_files(datadir)]
            assert "MatMulV2.csv" in names
            assert "MatMulV2.tmp.csv" not in names

    def test_subdirs(self):
        """Expect a CSV file inside a subdirectory to be found."""
        with tempfile.TemporaryDirectory() as td:
            datadir = Path(td)
            sub = datadir / "sub"
            sub.mkdir()
            (sub / "Op.csv").write_text("x")
            result = list(iter_csv_files(datadir))
            assert len(result) == 1


class TestLoadCsvFiles:
    """Failure-path checks for ``load_csv_files``."""

    def test_empty_dir_raises(self):
        """A directory holding only a non-CSV file must raise ``ValueError``."""
        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            root.joinpath("config.yaml").write_text("{}")
            with pytest.raises(ValueError):
                load_csv_files(root)

    def test_non_existent_dir_raises(self):
        """A path that does not exist must raise ``ValueError``."""
        missing = Path("/nonexistent/path")
        with pytest.raises(ValueError):
            load_csv_files(missing)


class TestProcessTheoryCsv:
    """Checks for ``process_theory_csv``."""

    def test_no_generator_returns_none(self):
        """With empty config and empty op metadata the call returns ``None``."""
        with tempfile.TemporaryDirectory() as tmp:
            target = Path(tmp, "UnknownKernel.csv")
            target.write_text('Input Shapes,Average Duration(us)\n"128,5120",10.0\n')

            # Empty assignments/patterns and empty op metadata give the
            # function nothing to associate with this CSV.
            call_kwargs = {
                "csv_path": target,
                "model_names": None,
                "config": {"assignments": {}, "patterns": {}},
                "op_meta": {},
                "file_index": 1,
                "total_files": 1,
            }
            outcome = process_theory_csv(**call_kwargs)
            assert outcome is None


class TestRunTheoryMode:
    """Tests for ``run_theory_mode`` with the runner's loaders stubbed out."""

    # Dotted path of the module whose attributes are patched in these tests.
    _RUNNER = "tools.perf_data_collection.grid_generator.runner"

    @classmethod
    def _stub_runner(cls, *, stub_csv_discovery):
        """Return a patcher that replaces the runner's loaders with canned values.

        ``load_shape_grid_config`` and ``load_op_mapping_metadata`` are always
        replaced. ``iter_csv_files`` is replaced with a stub returning an
        empty list only when ``stub_csv_discovery`` is true.
        """
        stubs = {
            "load_shape_grid_config": mock.MagicMock(
                return_value={"assignments": {}, "patterns": {}}
            ),
            "load_op_mapping_metadata": mock.MagicMock(return_value={}),
        }
        if stub_csv_discovery:
            stubs["iter_csv_files"] = mock.MagicMock(return_value=[])
        return mock.patch.multiple(cls._RUNNER, **stubs)

    def test_parses_target_models(self):
        """A comma-separated ``target_models`` with no CSV files yields total 0."""
        with tempfile.TemporaryDirectory() as td:
            datadir = Path(td)
            (datadir / "config.yaml").write_text("assignments: {}\npatterns: {}\n")
            (datadir / "op_mapping.yaml").write_text("operator_mappings: {}\n")

            with (
                mock.patch.object(Path, "resolve", return_value=datadir),
                self._stub_runner(stub_csv_discovery=True),
            ):
                args = argparse.Namespace(
                    target_models="dsv3,qwen332b",
                    rows=0,
                    seed=0,
                    max_hbm_gb=32.0,
                )
                # Only the first element of the returned pair is checked here.
                total, _ = run_theory_mode(args, datadir, [])
                assert total == 0

    def test_no_target_models(self):
        """``target_models=None`` with no CSV files yields total 0."""
        with tempfile.TemporaryDirectory() as td:
            datadir = Path(td)
            (datadir / "config.yaml").write_text("assignments: {}\npatterns: {}\n")

            with (
                mock.patch.object(Path, "resolve", return_value=datadir),
                self._stub_runner(stub_csv_discovery=True),
            ):
                args = argparse.Namespace(
                    target_models=None,
                    rows=50,
                    seed=42,
                    max_hbm_gb=0,
                )
                # Only the first element of the returned pair is checked here.
                total, _ = run_theory_mode(args, datadir, [])
                assert total == 0

    def test_with_csv_files_skipped(self):
        """A single CSV passed with empty config is reported as skipped."""
        with tempfile.TemporaryDirectory() as td:
            datadir = Path(td)
            (datadir / "config.yaml").write_text("assignments: {}\npatterns: {}\n")
            csv_file = Path(td) / "UnknownKernel.csv"
            csv_file.write_text('Input Shapes,Average Duration(us)\n"128,5120",10.0\n')

            with (
                mock.patch.object(Path, "resolve", return_value=datadir),
                # iter_csv_files is left unpatched; the CSV is passed explicitly.
                self._stub_runner(stub_csv_discovery=False),
            ):
                args = argparse.Namespace(
                    target_models=None,
                    rows=0,
                    seed=0,
                    max_hbm_gb=32.0,
                )
                total, skipped = run_theory_mode(args, datadir, [csv_file])
                assert total == 0
                assert len(skipped) == 1