#!/usr/bin/env python3
"""
snapshot_analyze.py 单元测试
"""

import os
import shutil
import sqlite3
import sys
import tempfile
import unittest

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "skills", "ascend-npu-snapshot-analyzer", "scripts"))

import snapshot_queries as q
from snapshot_analyze import (
    _format_bytes,
    _health_status,
    _metric_status,
    analyze_overview,
    analyze_peak,
    analyze_fragment,
    analyze_leak,
    analyze_oom,
    analyze_compare,
)


def _setup_test_db(db_path):
    """Create the snapshot database schema used by the tests.

    Creates the ``devices``, ``call_stacks``, ``segments``, ``blocks`` and
    ``traces`` tables in the SQLite database at *db_path*. Every statement
    uses ``CREATE TABLE IF NOT EXISTS``, so calling this again on an already
    initialised database leaves it unchanged.

    Args:
        db_path: Filesystem path of the SQLite database; the file is created
            by ``sqlite3.connect`` if it does not exist yet.
    """
    schema = """
        CREATE TABLE IF NOT EXISTS devices (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            device_index INTEGER UNIQUE NOT NULL,
            device_type TEXT DEFAULT 'Ascend-NPU'
        );
        CREATE TABLE IF NOT EXISTS call_stacks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            stack_hash TEXT UNIQUE,
            frames_json TEXT
        );
        CREATE TABLE IF NOT EXISTS segments (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            device_id INTEGER NOT NULL,
            address INTEGER,
            total_size INTEGER,
            allocated_size INTEGER,
            active_size INTEGER,
            requested_size INTEGER,
            stream INTEGER,
            segment_type TEXT,
            pool_id_0 INTEGER,
            pool_id_1 INTEGER,
            is_expandable INTEGER,
            stack_id INTEGER,
            FOREIGN KEY (device_id) REFERENCES devices(id),
            FOREIGN KEY (stack_id) REFERENCES call_stacks(id)
        );
        CREATE TABLE IF NOT EXISTS blocks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            segment_id INTEGER NOT NULL,
            address INTEGER,
            size INTEGER,
            requested_size INTEGER,
            state TEXT,
            stack_id INTEGER,
            FOREIGN KEY (segment_id) REFERENCES segments(id) ON DELETE CASCADE,
            FOREIGN KEY (stack_id) REFERENCES call_stacks(id)
        );
        CREATE TABLE IF NOT EXISTS traces (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            device_id INTEGER NOT NULL,
            trace_index INTEGER,
            action TEXT,
            addr INTEGER,
            device_free INTEGER,
            size INTEGER,
            stream INTEGER,
            stack_id INTEGER,
            FOREIGN KEY (device_id) REFERENCES devices(id),
            FOREIGN KEY (stack_id) REFERENCES call_stacks(id)
        );
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.executescript(schema)
        conn.commit()
    finally:
        # Release the handle even if schema creation fails, so the caller's
        # temporary directory can still be removed.
        conn.close()


class TestFormatBytes(unittest.TestCase):
    """Check that ``_format_bytes`` picks the expected unit suffix."""

    def test_bytes(self):
        # Below one KiB the raw byte count is printed with a "B" suffix.
        text = _format_bytes(512)
        self.assertEqual(text, "512 B")

    def test_kb(self):
        text = _format_bytes(2 * 1024)
        self.assertIn("KB", text)

    def test_mb(self):
        text = _format_bytes(5 << 20)  # 5 MiB
        self.assertIn("MB", text)

    def test_gb(self):
        text = _format_bytes(3 << 30)  # 3 GiB
        self.assertIn("GB", text)


class TestHealthStatus(unittest.TestCase):
    """Check the health level that ``_health_status`` assigns.

    The positional arguments appear to be (fragmentation percent, segment
    count, OOM-occurred flag). TODO: confirm against snapshot_analyze.
    Level labels: "健康" = healthy, "需关注" = needs attention, "严重" = severe.
    """

    def _level(self, frag_pct, segment_count, has_oom):
        """Return only the ``level`` entry of ``_health_status``'s result."""
        return _health_status(frag_pct, segment_count, has_oom)["level"]

    def test_healthy(self):
        self.assertEqual(self._level(3.0, 50, False), "健康")

    def test_warn_frag(self):
        # Fragmentation alone raises the level to "needs attention".
        self.assertEqual(self._level(10.0, 50, False), "需关注")

    def test_warn_segments(self):
        # A large segment count alone raises the level to "needs attention".
        self.assertEqual(self._level(3.0, 150, False), "需关注")

    def test_err_oom(self):
        # An OOM makes the status severe regardless of the other inputs.
        self.assertEqual(self._level(3.0, 50, True), "严重")

    def test_err_frag(self):
        self.assertEqual(self._level(20.0, 50, False), "严重")


class TestMetricStatus(unittest.TestCase):
    """Check ``_metric_status`` tags for values against a (warn, err) threshold pair."""

    def test_ok(self):
        thresholds = (5, 15)
        tag = _metric_status(3.0, thresholds)
        self.assertEqual(tag, "[OK]")

    def test_warn(self):
        thresholds = (5, 15)
        tag = _metric_status(10.0, thresholds)
        self.assertEqual(tag, "[WARN]")

    def test_err(self):
        thresholds = (5, 15)
        tag = _metric_status(20.0, thresholds)
        self.assertEqual(tag, "[ERR]")


class TestAnalyzeOverview(unittest.TestCase):
    """Tests for ``analyze_overview`` on a one-device, one-segment database."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # Register removal immediately: unittest does not call tearDown() when
        # setUp() raises, but it always runs registered cleanups.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_path = os.path.join(self.tmpdir, "test.db")
        _setup_test_db(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("INSERT INTO devices (device_index) VALUES (0)")
            # One 1 GiB expandable segment with 800 MiB allocated/active/requested.
            conn.execute(
                "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) VALUES (1, 0x1000, 1024*1024*1024, 800*1024*1024, 800*1024*1024, 800*1024*1024, 'large', 1)"
            )
            conn.commit()
        finally:
            conn.close()

    def test_overview_structure(self):
        result = analyze_overview(self.db_path)
        self.assertEqual(result["mode"], "overview")
        self.assertIn("health", result)
        self.assertIn("summary", result)
        self.assertIn("devices", result)
        self.assertEqual(result["device_count"], 1)
        self.assertEqual(result["oom_count"], 0)

    def test_summary_values(self):
        result = analyze_overview(self.db_path)
        summary = result["summary"]
        self.assertGreater(summary["reserved"], 0)
        self.assertGreater(summary["allocated"], 0)
        self.assertIsNotNone(summary["frag_pct"])


class TestAnalyzePeak(unittest.TestCase):
    """Tests for ``analyze_peak`` using a fixture with a call stack, a block and traces."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # tearDown() is skipped when setUp() raises; a registered cleanup is not.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_path = os.path.join(self.tmpdir, "test.db")
        _setup_test_db(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("INSERT INTO devices (device_index) VALUES (0)")
            # First call_stacks row (id 1), referenced below via stack_id = 1.
            conn.execute(
                "INSERT INTO call_stacks (stack_hash, frames_json) VALUES ('hash1', '[\"a.py:10:func_a\", \"b.py:20:func_b\"]')"
            )
            conn.execute(
                "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) VALUES (1, 0x1000, 1024*1024*1024, 800*1024*1024, 800*1024*1024, 800*1024*1024, 'large', 1)"
            )
            conn.execute(
                "INSERT INTO blocks (segment_id, address, size, requested_size, state, stack_id) VALUES (1, 0x1000, 800*1024*1024, 800*1024*1024, 'active_allocated', 1)"
            )
            # Trace history: +1 GiB (with stack), +512 MiB, then -512 MiB.
            conn.execute(
                "INSERT INTO traces (device_id, trace_index, action, size, stack_id) VALUES (1, 0, 'segment_alloc', 1024*1024*1024, 1)"
            )
            conn.execute(
                "INSERT INTO traces (device_id, trace_index, action, size) VALUES (1, 1, 'segment_alloc', 512*1024*1024)"
            )
            conn.execute(
                "INSERT INTO traces (device_id, trace_index, action, size) VALUES (1, 2, 'segment_free', 512*1024*1024)"
            )
            conn.commit()
        finally:
            conn.close()

    def test_peak_structure(self):
        result = analyze_peak(self.db_path)
        self.assertEqual(result["mode"], "peak")
        self.assertIn("peak", result)
        self.assertIn("baseline", result)
        self.assertIn("deltas", result)
        self.assertIn("peak_alloc_events", result)
        self.assertIn("peak_blocks", result)
        self.assertIn("devices", result)
        self.assertGreaterEqual(len(result["devices"]), 1)

    def test_peak_timeline(self):
        result = analyze_peak(self.db_path)
        self.assertGreater(result["peak"]["reserved"], 0)

    def test_peak_contributors_have_call_path(self):
        result = analyze_peak(self.db_path)
        events = result["peak_alloc_events"]["top"]
        self.assertGreater(len(events), 0)
        for evt in events:
            self.assertIn("call_path", evt)


class TestAnalyzeFragment(unittest.TestCase):
    """Tests for ``analyze_fragment`` on a single 1 GiB segment with 800 MiB allocated."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # tearDown() is skipped when setUp() raises; a registered cleanup is not.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_path = os.path.join(self.tmpdir, "test.db")
        _setup_test_db(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("INSERT INTO devices (device_index) VALUES (0)")
            conn.execute(
                "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) VALUES (1, 0x1000, 1024*1024*1024, 800*1024*1024, 800*1024*1024, 800*1024*1024, 'large', 1)"
            )
            conn.commit()
        finally:
            conn.close()

    def test_fragment_structure(self):
        result = analyze_fragment(self.db_path)
        self.assertEqual(result["mode"], "fragment")
        self.assertIn("overall", result)
        self.assertIn("top_fragmented", result)

    def test_frag_pct_calculation(self):
        result = analyze_fragment(self.db_path)
        # Expected: (reserved - allocated) / reserved, as a percentage rounded to 0.1.
        expected = round((1024 * 1024 * 1024 - 800 * 1024 * 1024) / (1024 * 1024 * 1024) * 100, 1)
        self.assertEqual(result["overall"]["frag_pct"], expected)


class TestAnalyzeLeak(unittest.TestCase):
    """Tests for ``analyze_leak`` with a trace of allocations and no frees."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # tearDown() is skipped when setUp() raises; a registered cleanup is not.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_path = os.path.join(self.tmpdir, "test.db")
        _setup_test_db(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("INSERT INTO devices (device_index) VALUES (0)")
            conn.execute(
                "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) VALUES (1, 0x1000, 1024*1024*1024, 800*1024*1024, 800*1024*1024, 800*1024*1024, 'large', 1)"
            )
            # Ten consecutive 1 KiB segment_alloc events with no matching frees,
            # which is the pattern the monotonic-growth check should flag.
            for i in range(10):
                conn.execute(
                    "INSERT INTO traces (device_id, trace_index, action, size, addr) VALUES (1, ?, 'segment_alloc', 1024, 0x1000)",
                    (i,),
                )
            conn.commit()
        finally:
            conn.close()

    def test_leak_structure(self):
        result = analyze_leak(self.db_path)
        self.assertEqual(result["mode"], "leak")
        self.assertIn("risk", result)
        self.assertIn("monotonic_growth", result)
        self.assertIn("suspects", result)

    def test_monotonic_detection(self):
        result = analyze_leak(self.db_path)
        self.assertTrue(result["monotonic_growth"]["detected"])


class TestAnalyzeOOM(unittest.TestCase):
    """Tests for ``analyze_oom`` with and without an ``oom`` trace event."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # tearDown() is skipped when setUp() raises; a registered cleanup is not.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_path = os.path.join(self.tmpdir, "test.db")
        _setup_test_db(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("INSERT INTO devices (device_index) VALUES (0)")
            # A single OOM at trace_index 100 reporting 128 MiB device_free ...
            conn.execute(
                "INSERT INTO traces (device_id, trace_index, action, device_free, size) VALUES (1, 100, 'oom', 128*1024*1024, 0)"
            )
            # ... preceded (by trace_index) by fifty 256 MiB allocations.
            for i in range(50, 100):
                conn.execute(
                    "INSERT INTO traces (device_id, trace_index, action, size) VALUES (1, ?, 'alloc', ?)",
                    (i, 256 * 1024 * 1024),
                )
            conn.commit()
        finally:
            conn.close()

    def test_oom_detected(self):
        result = analyze_oom(self.db_path)
        self.assertTrue(result["detected"])
        self.assertEqual(len(result["events"]), 1)

    def test_no_oom(self):
        db2 = os.path.join(self.tmpdir, "no_oom.db")
        _setup_test_db(db2)
        conn = sqlite3.connect(db2)
        try:
            conn.execute("INSERT INTO devices (device_index) VALUES (0)")
            conn.commit()
        finally:
            conn.close()
        result = analyze_oom(db2)
        self.assertFalse(result["detected"])


class TestAnalyzeCompare(unittest.TestCase):
    """Tests for ``analyze_compare`` between two snapshot databases.

    ``a.db`` holds one 1024-byte segment at 0x1000. ``b.db`` holds the same
    address grown to 2048 bytes plus an extra 512-byte segment at 0x2000.
    """

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # tearDown() is skipped when setUp() raises; a registered cleanup is not.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_a = os.path.join(self.tmpdir, "a.db")
        self.db_b = os.path.join(self.tmpdir, "b.db")

        for db_path, addr, size in [(self.db_a, 0x1000, 1024), (self.db_b, 0x1000, 2048)]:
            _setup_test_db(db_path)
            conn = sqlite3.connect(db_path)
            try:
                conn.execute("INSERT INTO devices (device_index) VALUES (0)")
                conn.execute(
                    "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) "
                    "VALUES (1, ?, ?, ?, ?, ?, 'small', 0)",
                    (addr, size, size, size, size),
                )
                conn.commit()
            finally:
                conn.close()

        # Schema already exists in db_b from the loop above; only add the
        # segment that is present in b but not in a.
        conn = sqlite3.connect(self.db_b)
        try:
            conn.execute(
                "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) "
                "VALUES (1, 0x2000, 512, 512, 512, 512, 'small', 0)"
            )
            conn.commit()
        finally:
            conn.close()

    def test_compare_structure(self):
        result = analyze_compare(self.db_a, self.db_b)
        self.assertEqual(result["mode"], "compare")
        self.assertIn("metrics", result)
        self.assertIn("new_segments", result)
        self.assertIn("grown_segments", result)

    def test_compare_growth(self):
        result = analyze_compare(self.db_a, self.db_b)
        reserved_metric = next(m for m in result["metrics"] if m["name"] == "Reserved")
        self.assertGreater(reserved_metric["diff"], 0)

    def test_new_segment(self):
        result = analyze_compare(self.db_a, self.db_b)
        self.assertEqual(result["new_segment_count"], 1)

    def test_ref_not_found(self):
        # Use a path inside the fresh temp dir: a cwd-relative name could
        # happen to exist and make this test flaky.
        missing = os.path.join(self.tmpdir, "nonexistent.db")
        result = analyze_compare(self.db_a, missing)
        self.assertIn("error", result)


class TestSnapshotQueries(unittest.TestCase):
    """Tests for the query helpers in ``snapshot_queries``.

    Fixture: two devices with one segment each. Both traces belong to
    device 1: one ``segment_alloc`` and one ``alloc``.
    """

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # tearDown() is skipped when setUp() raises; a registered cleanup is not.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_path = os.path.join(self.tmpdir, "test.db")
        _setup_test_db(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("INSERT INTO devices (device_index) VALUES (0)")
            conn.execute("INSERT INTO devices (device_index) VALUES (1)")
            conn.execute(
                "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) VALUES (1, 0x1000, 1024*1024*1024, 800*1024*1024, 800*1024*1024, 800*1024*1024, 'large', 1)"
            )
            conn.execute(
                "INSERT INTO segments (device_id, address, total_size, allocated_size, active_size, requested_size, segment_type, is_expandable) VALUES (2, 0x2000, 512*1024*1024, 400*1024*1024, 400*1024*1024, 400*1024*1024, 'small', 0)"
            )
            conn.execute("INSERT INTO traces (device_id, trace_index, action, size) VALUES (1, 0, 'segment_alloc', 1024)")
            conn.execute("INSERT INTO traces (device_id, trace_index, action, size) VALUES (1, 1, 'alloc', 512)")
            conn.commit()
        finally:
            conn.close()

    def test_get_device_overview(self):
        result = q.get_device_overview(self.db_path)
        self.assertEqual(len(result), 2)

    def test_get_device_count(self):
        self.assertEqual(q.get_device_count(self.db_path), 2)

    def test_get_segment_count(self):
        self.assertEqual(q.get_segment_count(self.db_path), 2)

    def test_get_trace_count(self):
        self.assertEqual(q.get_trace_count(self.db_path), 2)

    def test_get_block_state_dist(self):
        result = q.get_block_state_dist(self.db_path)
        self.assertIsInstance(result, list)

    def test_get_expansion_events(self):
        result = q.get_expansion_events(self.db_path)
        self.assertGreater(len(result), 0)

    def test_get_oom_events(self):
        result = q.get_oom_events(self.db_path)
        self.assertEqual(len(result), 0)

    def test_get_fragmentation_detail(self):
        result = q.get_fragmentation_detail(self.db_path)
        self.assertEqual(len(result), 2)

    def test_execute_sql(self):
        result = q.execute_sql(self.db_path, "SELECT COUNT(*) AS cnt FROM devices")
        self.assertEqual(result[0]["cnt"], 2)


if __name__ == "__main__":
    # Allow running this test module directly as a script; collects the
    # TestCase classes defined in this module (unittest's default target).
    unittest.main(module="__main__")