from __future__ import annotations

import asyncio
import inspect
import types
from contextlib import contextmanager

import pytest

try:
    from ._shared import get_contract_path, install_paths, install_stubs, run_contract_test_for_file
except ImportError:
    try:
        from plugin_contracts._shared import (
            get_contract_path,
            install_paths,
            install_stubs,
            run_contract_test_for_file,
        )
    except ImportError:
        from _shared import get_contract_path, install_paths, install_stubs, run_contract_test_for_file

# These two calls run at import time, ahead of the `slime` imports further
# down in this module.
# NOTE(review): presumably install_paths() adjusts the import path and
# install_stubs() registers stand-ins for sglang_router / transformers so the
# later imports resolve without the real packages — confirm against _shared.
install_paths()
install_stubs(with_sglang_router=True, with_transformers=True)

# NOTE(review): not referenced anywhere in the visible code; presumably read by
# the shared contract-test runner — confirm.
NUM_GPUS = 0
# Dotted paths naming the reference plugin functions `custom_generate` and
# `custom_generate_with_evaluation` under the
# `plugin_contracts.test_plugin_generate_contracts` module.
REFERENCE_CUSTOM_GENERATE_PATH = "plugin_contracts.test_plugin_generate_contracts.custom_generate"
REFERENCE_CUSTOM_GENERATE_WITH_EVAL_PATH = (
    "plugin_contracts.test_plugin_generate_contracts.custom_generate_with_evaluation"
)

from slime.rollout.sglang_rollout import generate_and_rm
from slime.utils.misc import load_function
from slime.utils.types import Sample


def run_contract_test_file() -> None:
    """Hand this file to the shared contract-test runner.

    Calls ``run_contract_test_for_file`` with this module's ``__file__`` and
    the single path argument name ``custom-generate-function-path``.
    """
    contract_path_args = ["custom-generate-function-path"]
    run_contract_test_for_file(__file__, path_args=contract_path_args)


def make_args(**overrides):
    """Build a minimal args object for the ``generate_and_rm`` tests.

    Defaults are class attributes of a throwaway ``Args`` class. Every keyword
    in ``overrides`` is set on the returned instance, so it either shadows a
    default or adds a new attribute.

    Returns:
        A fresh ``Args`` instance with the overrides applied.
    """

    class Args:
        partial_rollout = False
        mask_offpolicy_in_partial_rollout = False
        group_rm = False
        custom_generate_function_path = None
        sglang_enable_deterministic_inference = False
        rollout_seed = 7
        n_samples_per_prompt = 2

    configured = Args()
    # Instance-level assignment leaves the class defaults untouched for the
    # next call.
    for attr_name, attr_value in overrides.items():
        setattr(configured, attr_name, attr_value)
    return configured


class FakeGenerateState:
    """Lightweight stand-in carrying generate-state attributes for the tests.

    Built from an ``args`` object that provides ``rollout_seed`` and
    ``n_samples_per_prompt``.
    """

    def __init__(self, args) -> None:
        sample_offsets = range(args.n_samples_per_prompt)

        self.args = args
        # Placeholder only: ``__aenter__`` is an instance attribute set to
        # ``None``, so this object cannot be entered with ``async with``.
        # ``_PatchedGenerateState`` replaces it with a usable one.
        self.semaphore = types.SimpleNamespace(__aenter__=None)
        self.pendings = set()
        self.remaining_batch_size = 0
        self.aborted = False
        # One seed per sample in a group: rollout_seed, rollout_seed + 1, ...
        self.group_sampling_seeds = [args.rollout_seed + offset for offset in sample_offsets]

    @contextmanager
    def dp_rank_context(self):
        """Context manager that always yields rank ``0``."""
        yield 0


async def custom_generate(args, sample: Sample, sampling_params: dict):
    """Reference custom generate function without an ``evaluation`` parameter.

    Fills ``sample`` in place with a fixed three-token response, a reward of
    ``0.25`` and ``COMPLETED`` status, then returns that same sample.
    ``args`` and ``sampling_params`` are accepted but not used.
    """
    fixed_tokens = [11, 12, 13]

    sample.tokens = fixed_tokens
    sample.response = "generated"
    sample.response_length = len(fixed_tokens)
    sample.reward = 0.25
    sample.status = Sample.Status.COMPLETED
    return sample


async def custom_generate_with_evaluation(args, sample: Sample, sampling_params: dict, evaluation: bool = False):
    """Reference custom generate function that accepts ``evaluation``.

    Fills ``sample`` in place with a fixed two-token response. The response
    text and reward depend on the truthiness of ``evaluation``, and the flag
    itself is stored under ``sample.metadata["evaluation"]``. Returns the same
    sample. ``args`` and ``sampling_params`` are accepted but not used.
    """
    if evaluation:
        response_text, reward_value = "eval-generated", 0.5
    else:
        response_text, reward_value = "train-generated", 0.75

    fixed_tokens = [21, 22]
    sample.tokens = fixed_tokens
    sample.response = response_text
    sample.response_length = len(fixed_tokens)
    sample.reward = reward_value
    sample.status = Sample.Status.COMPLETED
    # Recorded so a caller can tell which mode this function was invoked in.
    sample.metadata["evaluation"] = evaluation
    return sample


def assert_sample_contract(sample: Sample) -> None:
    """Assert the basic shape of a generated sample.

    Checks that ``sample`` is a ``Sample``, that ``tokens`` is a list,
    ``response`` a str and ``response_length`` an int, and that ``reward`` is
    not ``None``.

    Raises:
        AssertionError: on the first check that fails.
    """
    assert isinstance(sample, Sample)

    # Checked in this order: tokens, response, response_length.
    expected_field_types = (
        ("tokens", list),
        ("response", str),
        ("response_length", int),
    )
    for field_name, field_type in expected_field_types:
        assert isinstance(getattr(sample, field_name), field_type)

    assert sample.reward is not None


def assert_custom_generate_signature_matches_expected(fn) -> None:
    """Assert that ``fn``'s leading parameters follow the custom-generate shape.

    The first three parameter names must be ``args``, ``sample`` and
    ``sampling_params``, in that order. Any further parameters are ignored.

    Raises:
        AssertionError: if the leading parameter names differ.
    """
    expected_prefix = ("args", "sample", "sampling_params")
    parameter_names = list(inspect.signature(fn).parameters)
    leading_names = tuple(parameter_names[: len(expected_prefix)])
    assert leading_names == expected_prefix


class _DummySemaphore:
    """No-op async context manager used where a semaphore is expected."""

    async def __aenter__(self) -> None:
        """Enter immediately; nothing is acquired and ``None`` is bound."""

    async def __aexit__(self, exc_type, exc, tb) -> bool:
        """Exit without suppressing any exception raised in the body."""
        return False


class _PatchedGenerateState(FakeGenerateState):
    """``FakeGenerateState`` whose semaphore works with ``async with``."""

    def __init__(self, args) -> None:
        super().__init__(args)
        # Swap the base class's placeholder for a no-op async context manager.
        usable_semaphore = _DummySemaphore()
        self.semaphore = usable_semaphore


@pytest.fixture
def patch_generate_state(monkeypatch):
    """Swap ``GenerateState`` for the test double and return the patched module.

    ``sglang_rollout.GenerateState`` is replaced with ``_PatchedGenerateState``
    through ``monkeypatch``. The ``sglang_rollout`` module object is returned
    so tests can patch further attributes on it.
    """
    from slime.rollout import sglang_rollout as rollout_module

    monkeypatch.setattr(rollout_module, "GenerateState", _PatchedGenerateState)
    return rollout_module


def test_generate_and_rm_default_generate_branch_is_stable(patch_generate_state, monkeypatch):
    """With no custom generate path, the module-level ``generate`` is expected to run.

    ``sglang_rollout.generate`` is replaced by a stub that stamps a
    recognisable response. The test then checks that the result of
    ``generate_and_rm`` carries that response and satisfies the sample contract.
    """

    async def fake_default_generate(args, sample: Sample, sampling_params: dict):
        sample.tokens = [31, 32]
        sample.response = "default-generate"
        sample.response_length = 2
        sample.reward = 1.0
        sample.status = Sample.Status.COMPLETED
        return sample

    rollout_module = patch_generate_state
    monkeypatch.setattr(rollout_module, "generate", fake_default_generate)

    default_args = make_args(custom_generate_function_path=None)
    prompt_sample = Sample(index=0, prompt="prompt")
    pending = generate_and_rm(
        default_args,
        prompt_sample,
        sampling_params={"temperature": 0.3},
        evaluation=False,
    )
    result = asyncio.run(pending)

    assert_sample_contract(result)
    assert result.response == "default-generate"


def test_generate_and_rm_prefers_per_sample_generate_function(patch_generate_state):
    """A sample's own ``generate_function_path`` is expected to win over the args-level path.

    The args point at ``custom_generate`` while the sample points at
    ``custom_generate_with_evaluation``. Only the latter writes
    ``metadata["evaluation"]``, so finding it set to ``True`` shows which
    function ran and that ``evaluation=True`` was forwarded.
    """
    args_level = make_args(custom_generate_function_path=REFERENCE_CUSTOM_GENERATE_PATH)
    per_sample = Sample(
        index=0,
        prompt="prompt",
        generate_function_path=REFERENCE_CUSTOM_GENERATE_WITH_EVAL_PATH,
    )

    pending = generate_and_rm(
        args_level,
        per_sample,
        sampling_params={"temperature": 0.3},
        evaluation=True,
    )
    result = asyncio.run(pending)

    assert_sample_contract(result)
    assert result.metadata["evaluation"] is True


def test_custom_generate_function_path_supports_user_override(patch_generate_state):
    """Check the custom generate function named by the contract path.

    The path comes from ``get_contract_path`` with the reference
    implementation passed as the second argument.
    NOTE(review): presumably that second argument is the fallback used when no
    user override is configured — confirm against ``_shared``.
    The loaded function must match the expected signature prefix, and running
    it through ``generate_and_rm`` must yield a sample that meets the contract.
    """
    user_generate_path = get_contract_path(
        "CUSTOM_GENERATE_FUNCTION_PATH",
        REFERENCE_CUSTOM_GENERATE_PATH,
    )
    user_generate = load_function(user_generate_path)
    assert_custom_generate_signature_matches_expected(user_generate)

    override_args = make_args(custom_generate_function_path=user_generate_path)
    prompt_sample = Sample(index=0, prompt="prompt")
    pending = generate_and_rm(
        override_args,
        prompt_sample,
        sampling_params={"temperature": 0.3},
        evaluation=False,
    )
    result = asyncio.run(pending)

    assert_sample_contract(result)


def test_generate_and_rm_group_rm_accepts_list_result_from_custom_generate(patch_generate_state, monkeypatch):
    """With ``group_rm=True``, a list returned by a custom generate function is expected back.

    ``sglang_rollout.load_function`` is patched to hand out a stub that
    returns two completed samples. The test checks that ``generate_and_rm``
    returns a two-element list of ``Sample`` objects.
    """

    async def custom_generate_list(args, sample: Sample, sampling_params: dict):
        sample.status = Sample.Status.COMPLETED
        sibling = Sample(index=1, prompt="prompt-1", status=Sample.Status.COMPLETED)
        return [sample, sibling]

    def fake_load_function(_path):
        # Ignore the requested path and always hand out the list-returning stub.
        return custom_generate_list

    rollout_module = patch_generate_state
    monkeypatch.setattr(rollout_module, "load_function", fake_load_function)

    group_args = make_args(custom_generate_function_path="plugin_contracts.fake_generate", group_rm=True)
    first_sample = Sample(index=0, prompt="prompt-0")
    pending = generate_and_rm(
        group_args,
        first_sample,
        sampling_params={"temperature": 0.3},
        evaluation=False,
    )
    result = asyncio.run(pending)

    assert isinstance(result, list)
    assert len(result) == 2
    assert all(isinstance(item, Sample) for item in result)


# When executed directly as a script, run this file through the contract-test
# runner instead of relying on pytest collection.
if __name__ == "__main__":
    run_contract_test_file()