# Copyright Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import importlib
import sys
import types
import unittest
from unittest.mock import patch


# Dotted names of the two modules under test; the tests import them afresh by name.
PROFILER_INTERFACE_MODULE = "serving_cast.profiler.profiler_interface"
PROFILER_STIME_MODULE = "serving_cast.profiler.profiler_stime"
# Meta-info value the interface test expects to be recorded under "service_type".
SERVICE_TYPE = "liuren_simulation"
# File size reported by the patched os.path.getsize in the parse-results test.
STABLE_FILE_SIZE = 16
# Values returned by the patched `now` helper in the SimProfiler tests.
START_TIME = 12.5
END_TIME = 15.0
# Value returned by the patched `current_task_name` helper in the SimProfiler tests.
TASK_NAME = "task-7"
# Sentinel default for build_fake_ms_service_profiler(level=...): "set no Level attribute".
MISSING = object()


class RecordingProfilerBase:
    """Stand-in profiler base class that logs every call made on it.

    Each method appends a tuple describing the call to ``self.calls``.
    ``metric`` and ``add_meta_info`` return the instance itself; the other
    methods return a fixed marker value identifying the call.
    """

    # Class-level marker that tests compare against when no top-level Level exists.
    Level = "CLASS_LEVEL"

    def __init__(self, level):
        """Store ``level`` and start with an empty call log."""
        self.level = level
        self.calls = []

    def metric(self, name, value):
        """Log a ``metric`` call and return ``self``."""
        entry = ("metric", name, value)
        self.calls.append(entry)
        return self

    def event(self, event_name):
        """Log an ``event`` call and return ``("event", event_name)``."""
        entry = ("event", event_name)
        self.calls.append(entry)
        return entry

    def span_start(self, span_name):
        """Log a ``span_start`` call and return ``("span_start", span_name)``."""
        entry = ("span_start", span_name)
        self.calls.append(entry)
        return entry

    def span_end(self):
        """Log a ``span_end`` call and return the string ``"span_end"``."""
        marker = "span_end"
        self.calls.append((marker,))
        return marker

    def add_meta_info(self, key, value):
        """Log an ``add_meta_info`` call and return ``self``."""
        entry = ("add_meta_info", key, value)
        self.calls.append(entry)
        return self


class RecordingInitProfiler:
    """Stand-in profiler that registers every instance created from it.

    Constructed objects are appended to the class-level ``instances`` list,
    and ``add_meta_info`` calls are logged per instance in ``calls``.
    """

    # Shared across all instances on purpose; tests clear it before use.
    instances = []

    def __init__(self, level):
        """Store ``level``, start an empty call log, and register the instance."""
        self.level = level
        self.calls = []
        registry = type(self).instances
        registry.append(self)

    def add_meta_info(self, key, value):
        """Log an ``add_meta_info`` call and return ``self``."""
        entry = ("add_meta_info", key, value)
        self.calls.append(entry)
        return self


def fresh_import(module_name):
    """Import ``module_name`` anew, discarding any copy cached in ``sys.modules``.

    Args:
        module_name: Absolute dotted name of the module to import.

    Returns:
        The newly imported module object.
    """
    # Evict the cached entry (if any) so import_module re-executes the module.
    if module_name in sys.modules:
        del sys.modules[module_name]
    reloaded = importlib.import_module(module_name)
    return reloaded


def build_fake_ms_service_profiler(*, profiler_cls, level=MISSING, parse_main=None):
    """Build an in-memory ``ms_service_profiler`` package and its ``parse`` submodule.

    Args:
        profiler_cls: Object exposed as ``ms_service_profiler.Profiler``.
        level: Object exposed as ``ms_service_profiler.Level``. When left at
            ``MISSING`` the package gets no ``Level`` attribute at all.
        parse_main: Callable exposed as ``ms_service_profiler.parse.main``.
            A falsy value is replaced by a no-op callable.

    Returns:
        A ``(package, parse_module)`` tuple of ``types.ModuleType`` objects.
    """
    fake_parse = types.ModuleType("ms_service_profiler.parse")
    fake_parse.main = parse_main if parse_main else (lambda: None)

    fake_package = types.ModuleType("ms_service_profiler")
    # Having a __path__ attribute is what makes the import system treat this as a package.
    fake_package.__path__ = []
    fake_package.Profiler = profiler_cls
    fake_package.parse = fake_parse
    if level is not MISSING:
        fake_package.Level = level

    return fake_package, fake_parse


class TestProfilerStime(unittest.TestCase):
    """Tests for the profiler_stime module, run against a fake ``ms_service_profiler``."""

    def tearDown(self):
        # Forget the module under test so the next test imports it from scratch.
        sys.modules.pop(PROFILER_STIME_MODULE, None)

    @staticmethod
    def _fake_profiler_installed(package, parse_module):
        # Returns a sys.modules patch (usable in ``with``) exposing the fake package.
        fake_modules = {
            "ms_service_profiler": package,
            "ms_service_profiler.parse": parse_module,
        }
        return patch.dict(sys.modules, fake_modules, clear=False)

    def test_import_uses_top_level_level(self):
        # A Level defined on the package itself must be the one the module exposes.
        sentinel_level = object()
        fakes = build_fake_ms_service_profiler(
            profiler_cls=RecordingProfilerBase,
            level=sentinel_level,
        )

        with self._fake_profiler_installed(*fakes):
            stime = fresh_import(PROFILER_STIME_MODULE)

        self.assertIs(stime.Level, sentinel_level)
        self.assertTrue(issubclass(stime.SimProfiler, RecordingProfilerBase))

    def test_import_falls_back_to_profiler_level(self):
        # Without a package-level Level, the module should expose Profiler.Level.
        fakes = build_fake_ms_service_profiler(profiler_cls=RecordingProfilerBase)

        with self._fake_profiler_installed(*fakes):
            stime = fresh_import(PROFILER_STIME_MODULE)

        self.assertEqual(stime.Level, RecordingProfilerBase.Level)

    def test_import_requires_profiler_package(self):
        # Make every import of ms_service_profiler (or a submodule) fail.
        real_import = __import__

        def blocking_import(name, globals=None, locals=None, fromlist=(), level=0):
            top_level_name = name.partition(".")[0]
            if top_level_name == "ms_service_profiler":
                raise ImportError("blocked for test")
            return real_import(name, globals, locals, fromlist, level)

        import_blocker = patch("builtins.__import__", side_effect=blocking_import)
        expect_failure = self.assertRaisesRegex(ImportError, "Please install ms_service_profiler")
        with import_blocker:
            with expect_failure:
                fresh_import(PROFILER_STIME_MODULE)

    def test_import_requires_level_on_fallback(self):
        # Neither the package nor its Profiler class carries a Level here.
        class ProfilerWithoutLevel:
            pass

        fakes = build_fake_ms_service_profiler(profiler_cls=ProfilerWithoutLevel)
        expected_message = r"ms_service_profiler\.Profiler has no Level; upgrade ms_service_profiler"

        with self._fake_profiler_installed(*fakes):
            with self.assertRaisesRegex(ImportError, expected_message):
                fresh_import(PROFILER_STIME_MODULE)

    def test_parse_main_func_delegates_to_parse_module(self):
        # The fake parse entry point records the argv it observes when invoked.
        seen_argv = []

        def record_argv():
            seen_argv.append(list(sys.argv))

        fakes = build_fake_ms_service_profiler(
            profiler_cls=RecordingProfilerBase,
            parse_main=record_argv,
        )
        cli_args = ["parse-profiler", "--input-path", "/tmp/demo"]

        with self._fake_profiler_installed(*fakes):
            stime = fresh_import(PROFILER_STIME_MODULE)
            saved_argv = sys.argv[:]
            try:
                sys.argv = list(cli_args)
                stime.parse_main_func()
            finally:
                # Always put the interpreter's argv back, even if the call raised.
                sys.argv = saved_argv

        self.assertEqual(seen_argv, [cli_args])

    def test_event_records_logical_timestamps_and_pid(self):
        # event() should emit start, end and pid metrics, then the event itself.
        fakes = build_fake_ms_service_profiler(profiler_cls=RecordingProfilerBase)
        expected_calls = [
            ("metric", "logical_start_time", START_TIME),
            ("metric", "logical_end_time", END_TIME),
            ("metric", "logical_pid", TASK_NAME),
            ("event", "Decode"),
        ]

        with self._fake_profiler_installed(*fakes):
            stime = fresh_import(PROFILER_STIME_MODULE)
            sim_profiler = stime.SimProfiler("INFO")
            # Two successive clock readings: first the start, then the end.
            with patch.object(stime, "now", side_effect=[START_TIME, END_TIME]):
                with patch.object(stime, "current_task_name", return_value=TASK_NAME):
                    outcome = sim_profiler.event("Decode")

        self.assertEqual(outcome, ("event", "Decode"))
        self.assertEqual(sim_profiler.calls, expected_calls)

    def test_span_start_records_start_timestamp_and_pid(self):
        # span_start() should emit start and pid metrics, then open the span.
        fakes = build_fake_ms_service_profiler(profiler_cls=RecordingProfilerBase)
        expected_calls = [
            ("metric", "logical_start_time", START_TIME),
            ("metric", "logical_pid", TASK_NAME),
            ("span_start", "Prefill"),
        ]

        with self._fake_profiler_installed(*fakes):
            stime = fresh_import(PROFILER_STIME_MODULE)
            sim_profiler = stime.SimProfiler("INFO")
            with patch.object(stime, "now", return_value=START_TIME):
                with patch.object(stime, "current_task_name", return_value=TASK_NAME):
                    outcome = sim_profiler.span_start("Prefill")

        self.assertEqual(outcome, ("span_start", "Prefill"))
        self.assertEqual(sim_profiler.calls, expected_calls)

    def test_span_end_records_end_timestamp_and_pid(self):
        # span_end() should emit end and pid metrics, then close the span.
        fakes = build_fake_ms_service_profiler(profiler_cls=RecordingProfilerBase)
        expected_calls = [
            ("metric", "logical_end_time", END_TIME),
            ("metric", "logical_pid", TASK_NAME),
            ("span_end",),
        ]

        with self._fake_profiler_installed(*fakes):
            stime = fresh_import(PROFILER_STIME_MODULE)
            sim_profiler = stime.SimProfiler("INFO")
            with patch.object(stime, "now", return_value=END_TIME):
                with patch.object(stime, "current_task_name", return_value=TASK_NAME):
                    outcome = sim_profiler.span_end()

        self.assertEqual(outcome, "span_end")
        self.assertEqual(sim_profiler.calls, expected_calls)


class TestProfilerInterface(unittest.TestCase):
    """Tests for the profiler_interface module, run against stubbed sibling modules."""

    def tearDown(self):
        # Forget the module under test so the next test imports it from scratch.
        sys.modules.pop(PROFILER_INTERFACE_MODULE, None)

    def test_import_without_supported_profiler_disables_profiling(self):
        # An attribute-less profiler_stime stub: the interface is expected to
        # report profiling as unavailable and reject every profiling call.
        empty_stime_module = types.ModuleType("serving_cast.profiler.profiler_stime")

        with patch.dict(
            sys.modules,
            {"serving_cast.profiler.profiler_stime": empty_stime_module},
            clear=False,
        ):
            module = fresh_import(PROFILER_INTERFACE_MODULE)

        self.assertFalse(module.is_profiling_ready())
        with self.assertRaisesRegex(ValueError, "profiling is not supported"):
            module.init_profiling()
        with self.assertRaisesRegex(ValueError, "profiling is not supported"):
            module.parse_profiling_results("/tmp/profiler")
        with self.assertRaisesRegex(RuntimeError, "profiling is not supported"):
            module.get_batch_type([])

    def test_supported_import_exposes_helpers_and_initializes_profiler(self):
        # Start from an empty registry so the instance count below is exact.
        RecordingInitProfiler.instances.clear()
        parse_calls = []

        class LevelStub:
            INFO = "INFO"

        fake_stime = types.ModuleType("serving_cast.profiler.profiler_stime")
        fake_stime.Level = LevelStub
        fake_stime.SimProfiler = RecordingInitProfiler
        fake_stime.parse_main_func = lambda: parse_calls.append(list(sys.argv))

        fake_utils = types.ModuleType("serving_cast.profiler.profiler_utils")

        def fake_get_batch_type(payload):
            return ("batch", payload)

        def fake_get_iter_size_info(queue, increase_iter_size):
            return ("iter", queue, increase_iter_size)

        def fake_queue_profiler(before_queue, after_queue, queue_name):
            return ("queue", before_queue, after_queue, queue_name)

        def fake_record_kv_cache_free_blocks(current_event, req_id, num_free_blocks):
            return ("kv", current_event, req_id, num_free_blocks)

        fake_utils.get_batch_type = fake_get_batch_type
        fake_utils.get_iter_size_info = fake_get_iter_size_info
        fake_utils.queue_profiler = fake_queue_profiler
        fake_utils.record_kv_cache_free_blocks = fake_record_kv_cache_free_blocks

        with patch.dict(
            sys.modules,
            {
                "serving_cast.profiler.profiler_stime": fake_stime,
                "serving_cast.profiler.profiler_utils": fake_utils,
            },
            clear=False,
        ):
            module = fresh_import(PROFILER_INTERFACE_MODULE)
            module.init_profiling()

        self.assertTrue(module.is_profiling_ready())
        # The interface must re-export the very objects supplied by the stubs.
        self.assertIs(module.Level, LevelStub)
        self.assertIs(module.get_batch_type, fake_get_batch_type)
        self.assertIs(module.get_iter_size_info, fake_get_iter_size_info)
        self.assertIs(module.queue_profiler, fake_queue_profiler)
        self.assertIs(module.record_kv_cache_free_blocks, fake_record_kv_cache_free_blocks)
        # Exactly one profiler is created, at INFO level, tagged with the service type.
        self.assertEqual(len(RecordingInitProfiler.instances), 1)
        self.assertEqual(RecordingInitProfiler.instances[0].level, LevelStub.INFO)
        self.assertEqual(
            RecordingInitProfiler.instances[0].calls,
            [("add_meta_info", "service_type", SERVICE_TYPE)],
        )
        # Initialization alone must not trigger parsing.
        self.assertEqual(parse_calls, [])

    def test_parse_profiling_results_waits_for_stable_size_before_parsing(self):
        parse_calls = []

        class LevelStub:
            INFO = "INFO"

        fake_stime = types.ModuleType("serving_cast.profiler.profiler_stime")
        fake_stime.Level = LevelStub
        fake_stime.SimProfiler = RecordingInitProfiler
        # Record the argv visible to the parser entry point at call time.
        fake_stime.parse_main_func = lambda: parse_calls.append(list(sys.argv))

        fake_utils = types.ModuleType("serving_cast.profiler.profiler_utils")
        fake_utils.get_batch_type = lambda payload: payload
        fake_utils.get_iter_size_info = lambda queue, increase_iter_size: (queue, increase_iter_size)
        fake_utils.queue_profiler = lambda before_queue, after_queue, queue_name: None
        fake_utils.record_kv_cache_free_blocks = lambda current_event, req_id, num_free_blocks: None

        profile_dir = "/tmp/profiling-run"
        expected_argv = [
            "python -m ms_service_profiler.parse",
            "--input-path",
            profile_dir,
            "--output-path",
            profile_dir + "_parsed_result",
        ]

        with patch.dict(
            sys.modules,
            {
                "serving_cast.profiler.profiler_stime": fake_stime,
                "serving_cast.profiler.profiler_utils": fake_utils,
            },
            clear=False,
        ):
            module = fresh_import(PROFILER_INTERFACE_MODULE)
            original_argv = sys.argv[:]
            argv_after_parse = None
            try:
                with (
                    # One file whose reported size never changes between polls.
                    patch.object(module.os, "walk", return_value=[(profile_dir, [], ["result.bin"])]),
                    patch.object(
                        module.os.path,
                        "getsize",
                        return_value=STABLE_FILE_SIZE,
                    ),
                    patch.object(module.time, "sleep") as sleep_mock,
                ):
                    module.parse_profiling_results(profile_dir)
                    # Snapshot argv before the ``finally`` below resets it; checking
                    # sys.argv after that reset would compare the list with itself.
                    argv_after_parse = sys.argv[:]
            finally:
                # Safety net: never leak a modified argv into other tests.
                sys.argv = original_argv

        self.assertEqual(parse_calls, [expected_argv])
        # parse_profiling_results itself must leave sys.argv as it found it.
        self.assertEqual(argv_after_parse, original_argv)
        self.assertEqual(sleep_mock.call_count, 10)
        sleep_mock.assert_called_with(0.1)


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()