"""Tests for memory_estimator.py — HBM memory estimation utilities."""

from tools.perf_data_collection.memory_estimator import (
    DTYPE_BYTES,
    DEFAULT_BYTES_PER_ELEMENT,
    DEFAULT_MAX_BYTES,
    dtype_to_bytes,
    estimate_row_memory,
    estimate_tensor_bytes,
    exceeds_memory_budget,
    format_bytes,
    parse_dtype_from_template_row,
    _parse_dtype_list,
)


class TestDtypeToBytes:
    """Exercise ``dtype_to_bytes`` across name spellings and fallbacks."""

    def test_known_dtype_with_prefix(self):
        """Each ``DT_``-prefixed name below must map to the listed width."""
        widths_by_name = {
            "DT_FLOAT": 4,
            "DT_FLOAT16": 2,
            "DT_BF16": 2,
            "DT_BFLOAT16": 2,
            "DT_FLOAT32": 4,
            "DT_FLOAT64": 8,
            "DT_INT8": 1,
            "DT_INT16": 2,
            "DT_INT32": 4,
            "DT_INT64": 8,
            "DT_UINT8": 1,
            "DT_UINT16": 2,
            "DT_UINT32": 4,
            "DT_UINT64": 8,
            "DT_BOOL": 1,
            "DT_COMPLEX64": 8,
            "DT_COMPLEX128": 16,
            "DT_FLOAT8_E4M3": 1,
            "DT_FLOAT8_E5M2": 1,
            "DT_FLOAT8": 1,
        }
        for dtype_name, width in widths_by_name.items():
            assert dtype_to_bytes(dtype_name) == width

    def test_known_dtype_without_prefix(self):
        """Names spelled without the ``DT_`` prefix must resolve as well."""
        bare_cases = [
            ("FLOAT", 4),
            ("INT8", 1),
            ("INT16", 2),
            ("INT32", 4),
            ("INT64", 8),
            ("UINT8", 1),
            ("BOOL", 1),
        ]
        for dtype_name, width in bare_cases:
            assert dtype_to_bytes(dtype_name) == width

    def test_case_insensitive(self):
        """Lower-case and mixed-case spellings must give the same widths."""
        mixed_case = [("dt_float16", 2), ("Dt_Int32", 4), ("dt_bfloat16", 2)]
        for dtype_name, width in mixed_case:
            assert dtype_to_bytes(dtype_name) == width

    def test_whitespace_handling(self):
        """Surrounding spaces or a trailing newline must not affect lookup."""
        padded = [("  DT_FLOAT  ", 4), ("DT_BF16\n", 2)]
        for dtype_name, width in padded:
            assert dtype_to_bytes(dtype_name) == width

    def test_unknown_dtype_falls_back_to_fp16(self):
        """Unrecognised or empty names must yield DEFAULT_BYTES_PER_ELEMENT."""
        for unrecognised in ("UNKNOWN_TYPE", ""):
            assert dtype_to_bytes(unrecognised) == DEFAULT_BYTES_PER_ELEMENT

    def test_all_known_dtypes_have_valid_size(self):
        """Every value in DTYPE_BYTES must be a strictly positive ``int``."""
        for name in DTYPE_BYTES:
            size = DTYPE_BYTES[name]
            is_positive_int = isinstance(size, int) and size > 0
            assert is_positive_int, f"Bad size for {name}: {size}"


class TestParseDtypeList:
    """Exercise ``_parse_dtype_list`` with the separators and edge inputs below."""

    def test_semicolon_separated(self):
        """A ``;``-delimited string splits into one entry per dtype."""
        parsed = _parse_dtype_list("DT_BF16;DT_BF16;DT_INT32")
        assert parsed == ["DT_BF16", "DT_BF16", "DT_INT32"]

    def test_space_separated(self):
        """Spaces (including a doubled space) also act as separators."""
        parsed = _parse_dtype_list("DT_BF16 DT_BF16  DT_INT32")
        assert parsed == ["DT_BF16", "DT_BF16", "DT_INT32"]

    def test_single_dtype(self):
        """A lone dtype comes back as a one-element list."""
        parsed = _parse_dtype_list("DT_FLOAT16")
        assert parsed == ["DT_FLOAT16"]

    def test_empty_string(self):
        """An empty string produces an empty list."""
        parsed = _parse_dtype_list("")
        assert parsed == []

    def test_quoted_string(self):
        """Double quotes wrapping the whole value are not part of any entry."""
        parsed = _parse_dtype_list('"DT_BF16;DT_INT32"')
        assert parsed == ["DT_BF16", "DT_INT32"]

    def test_none_coerced_to_empty_string_list(self):
        """``None`` is accepted and produces an empty list."""
        assert _parse_dtype_list(None) == []


class TestEstimateTensorBytes:
    """Exercise ``estimate_tensor_bytes`` for the shapes and widths below."""

    def test_basic(self):
        """A 2-D shape costs rows * cols * bytes-per-element."""
        rows, cols, width = 128, 5120, 2
        assert estimate_tensor_bytes((rows, cols), width) == rows * cols * width

    def test_empty_shape(self):
        """An empty shape tuple is expected to cost 0 bytes."""
        no_dims = ()
        assert estimate_tensor_bytes(no_dims, 4) == 0

    def test_single_dim(self):
        """A 1-D shape costs length * bytes-per-element."""
        length, width = 1024, 4
        assert estimate_tensor_bytes((length,), width) == length * width

    def test_large_shape(self):
        """With 1-byte elements the cost equals the element count."""
        rows, cols = 4096, 8192
        assert estimate_tensor_bytes((rows, cols), 1) == rows * cols

    def test_zero_in_shape(self):
        """Any zero-sized dimension makes the whole tensor cost 0 bytes."""
        degenerate = (0, 100)
        assert estimate_tensor_bytes(degenerate, 2) == 0


class TestEstimateRowMemory:
    """Exercise ``estimate_row_memory``, which totals input and output tensors."""

    def test_basic_two_inputs(self):
        """Two BF16 inputs plus one BF16 output are summed at 2 bytes each."""
        lhs, rhs, out = (128, 5120), (5120, 768), (128, 768)
        bf16_width = 2
        estimated = estimate_row_memory(
            input_shapes=[lhs, rhs],
            output_shapes=[out],
            input_dtypes=["DT_BF16", "DT_BF16"],
            output_dtypes=["DT_BF16"],
        )
        element_count = lhs[0] * lhs[1] + rhs[0] * rhs[1] + out[0] * out[1]
        assert estimated == element_count * bf16_width

    def test_output_dtypes_default_to_input_dtypes(self):
        """With ``output_dtypes=None`` the output is charged at 4 bytes too."""
        src, dst = (100, 200), (100, 300)
        fp32_width = 4
        estimated = estimate_row_memory(
            input_shapes=[src],
            output_shapes=[dst],
            input_dtypes=["DT_FLOAT32"],
            output_dtypes=None,
        )
        assert estimated == (src[0] * src[1] + dst[0] * dst[1]) * fp32_width

    def test_missing_dtype_falls_back(self):
        """A second input with no dtype entry is still charged 2 bytes/element."""
        # NOTE(review): 2 bytes matches BF16 and, presumably, the default
        # width, so this test cannot tell which fallback is applied - confirm.
        first, second, out = (10, 10), (10, 20), (10, 20)
        estimated = estimate_row_memory(
            input_shapes=[first, second],
            output_shapes=[out],
            input_dtypes=["DT_BF16"],
            output_dtypes=["DT_BF16"],
        )
        element_count = (
            first[0] * first[1] + second[0] * second[1] + out[0] * out[1]
        )
        assert estimated == element_count * 2

    def test_fp8_estimation(self):
        """FP8 (E4M3) tensors are charged one byte per element."""
        src, dst = (4096, 8192), (4096, 4096)
        estimated = estimate_row_memory(
            input_shapes=[src],
            output_shapes=[dst],
            input_dtypes=["DT_FLOAT8_E4M3"],
            output_dtypes=["DT_FLOAT8_E4M3"],
        )
        assert estimated == src[0] * src[1] + dst[0] * dst[1]

    def test_empty_inputs(self):
        """With no inputs, only the BF16 output contributes to the total."""
        out = (100, 200)
        estimated = estimate_row_memory(
            input_shapes=[],
            output_shapes=[out],
            input_dtypes=[],
            output_dtypes=["DT_BF16"],
        )
        assert estimated == out[0] * out[1] * 2


class TestExceedsMemoryBudget:
    """Exercise ``exceeds_memory_budget``, unpacked here as (flag, estimate)."""

    def test_under_budget(self):
        """Two small BF16 tensors stay below DEFAULT_MAX_BYTES."""
        over, estimate = exceeds_memory_budget(
            input_shapes=[(128, 5120)],
            output_shapes=[(128, 768)],
            input_dtypes=["DT_BF16"],
            output_dtypes=["DT_BF16"],
        )
        assert not over
        assert DEFAULT_MAX_BYTES > estimate

    def test_over_budget(self):
        """Three (2**20, 4096) float32 tensors must exceed DEFAULT_MAX_BYTES."""
        huge = (1 << 20, 4096)
        over, estimate = exceeds_memory_budget(
            input_shapes=[huge, huge],
            output_shapes=[huge],
            input_dtypes=["DT_FLOAT32", "DT_FLOAT32"],
            output_dtypes=["DT_FLOAT32"],
        )
        assert over
        assert DEFAULT_MAX_BYTES < estimate

    def test_custom_budget(self):
        """A caller-supplied ``max_bytes`` is used for the comparison."""
        limit = 1000
        over, estimate = exceeds_memory_budget(
            input_shapes=[(100, 100)],
            output_shapes=[],
            input_dtypes=["DT_FLOAT32"],
            max_bytes=limit,
        )
        assert over
        assert limit < estimate

    def test_exact_budget_boundary(self):
        """A budget of 100 * 200 * 2 bytes is not reported as exceeded."""
        rows, cols, bf16_width = 100, 200, 2
        limit = rows * cols * bf16_width
        over, _estimate = exceeds_memory_budget(
            input_shapes=[(rows, cols)],
            output_shapes=[],
            input_dtypes=["DT_BF16"],
            max_bytes=limit,
        )
        assert not over


class TestFormatBytes:
    """Exercise ``format_bytes`` for each unit suffix checked below."""

    def test_bytes(self):
        """Values below 1024 are rendered as a plain byte count."""
        rendered = format_bytes(500)
        assert rendered == "500 B"

    def test_kib(self):
        """2048 bytes renders as exactly two KiB with two decimals."""
        rendered = format_bytes(2048)
        assert rendered == "2.00 KiB"

    def test_mib(self):
        """Five mebibytes must carry the MiB suffix."""
        rendered = format_bytes(5 << 20)
        assert "MiB" in rendered

    def test_gib(self):
        """Three gibibytes must carry the GiB suffix."""
        rendered = format_bytes(3 << 30)
        assert "GiB" in rendered

    def test_zero(self):
        """Zero renders as a byte count."""
        rendered = format_bytes(0)
        assert rendered == "0 B"

    def test_gib_precise(self):
        """Two gibibytes renders with two decimals and the GiB suffix."""
        rendered = format_bytes(2 << 30)
        assert rendered == "2.00 GiB"


class TestParseDtypeFromTemplateRow:
    """Exercise ``parse_dtype_from_template_row`` on dict-shaped rows."""

    def test_basic(self):
        """Quoted, ``;``-separated columns come back as two dtype lists."""
        template_row = {
            "Input Data Types": '"DT_BF16;DT_BF16;INT32"',
            "Output Data Types": '"DT_BF16;DT_BF16;DT_BF16"',
        }
        in_dtypes, out_dtypes = parse_dtype_from_template_row(template_row)
        assert in_dtypes == ["DT_BF16", "DT_BF16", "INT32"]
        assert out_dtypes == ["DT_BF16"] * 3

    def test_empty_row(self):
        """A row with neither column yields two empty lists."""
        in_dtypes, out_dtypes = parse_dtype_from_template_row({})
        assert in_dtypes == []
        assert out_dtypes == []