from __future__ import annotations

import inspect
from dataclasses import dataclass
from pathlib import Path

import pytest

try:
    from ._shared import get_contract_path, install_paths, install_stubs, run_contract_test_for_file
except ImportError:
    try:
        from plugin_contracts._shared import (
            get_contract_path,
            install_paths,
            install_stubs,
            run_contract_test_for_file,
        )
    except ImportError:
        from _shared import get_contract_path, install_paths, install_stubs, run_contract_test_for_file

# Both helpers run at import time, ahead of the ``slime`` imports further
# down. NOTE(review): presumably they extend sys.path and register stub
# modules so those imports resolve — confirm in ``_shared``.
install_paths()
install_stubs()

# NOTE(review): not referenced in the visible part of this file; presumably
# read by the contract-test runner — confirm before changing.
NUM_GPUS = 0

from slime.utils.misc import load_function
from slime.utils.types import Sample


def run_contract_test_file() -> None:
    """Hand this file to the shared ``run_contract_test_for_file`` driver.

    Passes this module's ``__file__`` together with the five ``path_args``
    names listed below.
    """
    hook_path_args = [
        "custom-rollout-log-function-path",
        "custom-eval-rollout-log-function-path",
        "custom-reward-post-process-path",
        "custom-convert-samples-to-train-data-path",
        "rollout-data-postprocess-path",
    ]
    run_contract_test_for_file(__file__, path_args=hook_path_args)


def reference_custom_rollout_log(rollout_id, args, samples, rollout_extra_metrics, rollout_time) -> bool:
    """Reference rollout-log hook: store ``rollout_id`` on ``args`` and return ``True``.

    ``samples``, ``rollout_extra_metrics`` and ``rollout_time`` are accepted
    but not used.
    """
    setattr(args, "logged_rollout_id", rollout_id)
    return True


def reference_custom_eval_rollout_log(rollout_id, args, data, extra_metrics) -> bool:
    """Reference eval rollout-log hook: store ``rollout_id`` on ``args`` and return ``True``.

    ``data`` and ``extra_metrics`` are accepted but not used.
    """
    setattr(args, "logged_eval_rollout_id", rollout_id)
    return True


def reference_reward_post_process(args, samples):
    """Reference reward post-process hook.

    Returns a ``(raw_rewards, rewards)`` pair of lists: the first holds each
    sample's ``reward`` unchanged, the second holds the same values shifted
    by ``+1.0``. ``args`` is accepted but not used.
    """
    raw_rewards = []
    rewards = []
    for sample in samples:
        raw = sample.reward
        raw_rewards.append(raw)
        rewards.append(raw + 1.0)
    return raw_rewards, rewards


def reference_convert_samples_to_train_data(args, samples):
    """Reference sample-to-train-data hook.

    Builds a dict of per-sample lists read from each sample's attributes.
    ``rewards`` and ``raw_reward`` both carry ``sample.reward`` (as separate
    lists), and ``truncated`` is ``0`` for every sample. ``args`` is accepted
    but not used.
    """

    def column(attr):
        # One list per attribute, in sample order.
        return [getattr(sample, attr) for sample in samples]

    train_data = {}
    train_data["tokens"] = column("tokens")
    train_data["response_lengths"] = column("response_length")
    train_data["rewards"] = column("reward")
    train_data["raw_reward"] = column("reward")
    train_data["truncated"] = [0 for _ in samples]
    train_data["sample_indices"] = column("index")
    train_data["loss_masks"] = column("loss_mask")
    return train_data


def reference_rollout_data_postprocess(args) -> None:
    """Reference rollout-data postprocess hook: flag on ``args`` that it ran."""
    setattr(args, "rollout_data_postprocess_called", True)


def make_sample(index: int, reward: float = 1.0) -> Sample:
    """Build a small ``Sample`` fixture for the hook invokers.

    The sample carries two tokens (``index`` and ``index + 1``), a response
    length of 2, an all-ones loss mask of length 2 and status
    ``Sample.Status.COMPLETED``.
    """
    fields = {
        "index": index,
        "reward": reward,
        "tokens": [index, index + 1],
        "response_length": 2,
        "status": Sample.Status.COMPLETED,
        "loss_mask": [1, 1],
    }
    return Sample(**fields)


@dataclass(frozen=True)
class HookCase:
    """Immutable description of one runtime hook checked by the tests in this file."""

    # Used as the pytest parametrization id.
    name: str
    # Passed with ``default_path`` to ``get_contract_path`` to pick the
    # function path that ``load_function`` loads.
    env_key: str
    default_path: str
    # File whose text is searched for ``runtime_marker``.
    source_path: str
    # Exact call-site text that must occur in ``source_path``.
    runtime_marker: str
    # Parameter names, in order, that the loaded hook's signature must have.
    expected_params: tuple[str, ...]
    # Callable invoked with the loaded hook function to exercise it.
    invoke: object


def invoke_custom_rollout_log(fn):
    """Call a rollout-log hook and check its observable contract.

    The hook must return a ``bool`` and set ``logged_rollout_id`` on the
    args object to the rollout id it was given (3).
    """
    hook_args = type("Args", (), {})()
    outcome = fn(3, hook_args, [Sample(index=0)], {"reward": 1.0}, 0.5)
    assert isinstance(outcome, bool)
    assert hook_args.logged_rollout_id == 3


def invoke_custom_eval_rollout_log(fn):
    """Call an eval rollout-log hook and check its observable contract.

    The hook must return a ``bool`` and set ``logged_eval_rollout_id`` on the
    args object to the rollout id it was given (4).
    """
    hook_args = type("Args", (), {})()
    eval_sample = Sample(index=0, reward=1.0)
    eval_data = {"eval_set": {"rewards": [1.0], "truncated": [False], "samples": [eval_sample]}}
    outcome = fn(4, hook_args, eval_data, {"acc": 1.0})
    assert isinstance(outcome, bool)
    assert hook_args.logged_eval_rollout_id == 4


def invoke_reward_post_process(fn):
    """Call a reward post-process hook on two samples.

    The hook must return a pair of sequences, each of length 2.
    """
    hook_args = type("Args", (), {})()
    batch = [make_sample(0, 0.5), make_sample(1, 1.5)]
    raw_rewards, rewards = fn(hook_args, batch)
    assert len(raw_rewards) == len(rewards) == 2


def invoke_convert_samples_to_train_data(fn):
    """Call a sample-to-train-data hook on two samples.

    The returned mapping must contain at least the required keys listed
    below; extra keys are allowed.
    """
    required_keys = {
        "tokens",
        "response_lengths",
        "rewards",
        "raw_reward",
        "truncated",
        "sample_indices",
        "loss_masks",
    }
    hook_args = type("Args", (), {})()
    batch = [make_sample(0, 0.5), make_sample(1, 1.5)]
    train_data = fn(hook_args, batch)
    assert required_keys <= set(train_data)


def invoke_rollout_data_postprocess(fn):
    """Call a rollout-data postprocess hook and check its observable contract.

    The hook must return ``None`` and set ``rollout_data_postprocess_called``
    to ``True`` on the args object it receives.
    """
    hook_args = type("Args", (), {})()
    outcome = fn(hook_args)
    assert outcome is None
    assert hook_args.rollout_data_postprocess_called is True


# One entry per runtime hook covered by this file; both parametrized tests
# below iterate over this list.
HOOK_CASES = [
    HookCase(
        name="custom_rollout_log",
        env_key="CUSTOM_ROLLOUT_LOG_FUNCTION_PATH",
        default_path="plugin_contracts.test_plugin_runtime_hook_contracts.reference_custom_rollout_log",
        source_path="slime/ray/rollout.py",
        runtime_marker="custom_log_func(rollout_id, args, samples, rollout_extra_metrics, rollout_time)",
        expected_params=("rollout_id", "args", "samples", "rollout_extra_metrics", "rollout_time"),
        invoke=invoke_custom_rollout_log,
    ),
    HookCase(
        name="custom_eval_rollout_log",
        env_key="CUSTOM_EVAL_ROLLOUT_LOG_FUNCTION_PATH",
        default_path="plugin_contracts.test_plugin_runtime_hook_contracts.reference_custom_eval_rollout_log",
        source_path="slime/ray/rollout.py",
        runtime_marker="custom_log_func(rollout_id, args, data, extra_metrics)",
        expected_params=("rollout_id", "args", "data", "extra_metrics"),
        invoke=invoke_custom_eval_rollout_log,
    ),
    HookCase(
        name="custom_reward_post_process",
        env_key="CUSTOM_REWARD_POST_PROCESS_PATH",
        default_path="plugin_contracts.test_plugin_runtime_hook_contracts.reference_reward_post_process",
        source_path="slime/ray/rollout.py",
        runtime_marker="self.custom_reward_post_process_func(self.args, samples)",
        expected_params=("args", "samples"),
        invoke=invoke_reward_post_process,
    ),
    HookCase(
        name="custom_convert_samples_to_train_data",
        env_key="CUSTOM_CONVERT_SAMPLES_TO_TRAIN_DATA_PATH",
        default_path="plugin_contracts.test_plugin_runtime_hook_contracts.reference_convert_samples_to_train_data",
        source_path="slime/ray/rollout.py",
        runtime_marker="self.custom_convert_samples_to_train_data_func(self.args, samples)",
        expected_params=("args", "samples"),
        invoke=invoke_convert_samples_to_train_data,
    ),
    HookCase(
        name="rollout_data_postprocess",
        env_key="ROLLOUT_DATA_POSTPROCESS_PATH",
        default_path="plugin_contracts.test_plugin_runtime_hook_contracts.reference_rollout_data_postprocess",
        source_path="slime/backends/megatron_utils/actor.py",
        runtime_marker="self.rollout_data_postprocess(self.args)",
        expected_params=("args",),
        invoke=invoke_rollout_data_postprocess,
    ),
]


@pytest.mark.parametrize("case", HOOK_CASES, ids=[case.name for case in HOOK_CASES])
def test_runtime_hook_callsite_is_stable(case: HookCase):
    """Check that the runtime source still contains the expected hook call.

    Reads ``case.source_path`` (a relative path, so it is resolved against
    the current working directory) and asserts that ``case.runtime_marker``
    occurs in it verbatim.
    """
    # Decode explicitly as UTF-8: read_text() otherwise falls back to the
    # locale's encoding, which can fail or mis-decode on non-UTF-8 systems.
    source_text = Path(case.source_path).read_text(encoding="utf-8")
    assert (
        case.runtime_marker in source_text
    ), f"{case.name}: call-site marker {case.runtime_marker!r} not found in {case.source_path}"


@pytest.mark.parametrize("case", HOOK_CASES, ids=[case.name for case in HOOK_CASES])
def test_runtime_hook_path_aligns_with_expected_format(case: HookCase):
    """Load the hook selected for ``case`` and check its signature and behaviour.

    The function path comes from ``get_contract_path(env_key, default_path)``;
    the loaded function's parameter names must equal ``case.expected_params``
    exactly, after which ``case.invoke`` exercises it.
    """
    hook_path = get_contract_path(case.env_key, case.default_path)
    hook_fn = load_function(hook_path)
    param_names = tuple(inspect.signature(hook_fn).parameters)
    assert param_names == case.expected_params
    case.invoke(hook_fn)


# When executed directly as a script, run this file via run_contract_test_file().
if __name__ == "__main__":
    run_contract_test_file()