from __future__ import annotations

import json
import sys

import numpy as np
import pytest
import torch

from inference_service.core.compiled_policy import (
    ACTCompiledAdapter,
    CompiledPolicyWrapper,
    OMRuntimeSession,
    PI05CompiledAdapter,
    PI05OMRuntimeSession,
    PI05RuntimeInputs,
    SD3403RuntimeSession,
    create_compiled_model_adapter,
    create_runtime_session,
    load_compiled_manifest,
    resolve_om_model_path,
    resolve_pi05_om_paths,
)


class FakeRuntimeSession:
    """In-memory stand-in for a compiled runtime session.

    Records the arguments passed to ``load`` and ``execute`` so tests can
    inspect them, and always returns the configured output wrapped in a list.
    """

    def __init__(self, output=None):
        # Default output: an all-zero float32 array of shape (1, 2, 6).
        if output is None:
            output = np.zeros((1, 2, 6), dtype=np.float32)
        self.output = output
        self.loaded = None  # (policy_path, config, device) once load() is called
        self.inputs = None  # most recent argument passed to execute()

    def load(self, policy_path, config, device):
        """Remember the load arguments as a 3-tuple."""
        self.loaded = policy_path, config, device

    def execute(self, inputs):
        """Record ``inputs`` and return ``[self.output]``."""
        self.inputs = inputs
        return [self.output]

    def release(self):
        """Do nothing: the fake holds no resources."""
        return None


def _act_config(**updates):
    config = {
        "type": "act",
        "chunk_size": 2,
        "input_features": {
            "observation.state": {"shape": [3]},
            "observation.images.side": {"shape": [3, 4, 5]},
            "observation.images.gripper": {"shape": [3, 4, 5]},
        },
        "output_features": {"action": {"shape": [6]}},
    }
    config.update(updates)
    return config


def _write_policy(tmp_path, config):
    (tmp_path / "config.json").write_text(json.dumps(config), encoding="utf-8")


def _write_manifest(tmp_path, manifest):
    (tmp_path / "config.om.json").write_text(json.dumps(manifest), encoding="utf-8")


def _pi05_config(**updates):
    config = {
        "type": "pi05",
        "chunk_size": 50,
        "max_action_dim": 32,
        "num_inference_steps": 10,
        "input_features": {
            "observation.images.front": {"type": "VISUAL", "shape": [3, 224, 224]},
            "observation.language.tokens": {"shape": [48]},
            "observation.language.attention_mask": {"shape": [48]},
        },
        "output_features": {"action": {"shape": [6]}},
    }
    config.update(updates)
    return config


def test_adapter_selection_from_config_type():
    """The adapter factory picks the adapter class from the config's ``type``."""
    cases = (
        (_act_config(), "rknn", ACTCompiledAdapter, "act"),
        (_pi05_config(), "ascend_om", PI05CompiledAdapter, "pi05"),
    )
    for config, backend, expected_cls, expected_type in cases:
        chosen = create_compiled_model_adapter(config, backend)

        assert isinstance(chosen, expected_cls)
        assert chosen.policy_type == expected_type
        assert chosen.uses_action_chunking is True


def test_adapter_rejects_missing_and_unsupported_type():
    """Configs without a type, or with a type rknn cannot serve, raise ``ValueError``."""
    bad_cases = [
        ({"input_features": {}}, "missing required type"),
        ({"type": "diffusion"}, "does not support policy type"),
        (_pi05_config(), "does not support PI05"),
    ]
    for config, message in bad_cases:
        with pytest.raises(ValueError, match=message):
            create_compiled_model_adapter(config, "rknn")


def test_act_input_mapping_uses_declared_order_and_camera_names():
    """Prepared inputs are batched and ordered as declared in ``input_features``."""
    adapter = ACTCompiledAdapter.from_config(_act_config(), "rknn")
    observation = {
        "observation.state": torch.full((3,), 1.0),
        "observation.images.side": torch.full((3, 4, 5), 2.0),
        "observation.images.gripper": torch.full((3, 4, 5), 3.0),
    }

    prepared = adapter.prepare_inputs(observation)

    assert [array.shape for array in prepared] == [(1, 3), (1, 3, 4, 5), (1, 3, 4, 5)]
    # Each source tensor has a distinct fill value, so the first element of
    # each prepared array identifies which tensor landed in which slot.
    for array, fill in zip(prepared, (1.0, 2.0, 3.0)):
        first_index = (0,) * array.ndim
        assert float(array[first_index]) == fill


def test_act_input_mapping_rejects_missing_tensor():
    """A declared input feature absent from the observation raises ``KeyError``."""
    adapter = ACTCompiledAdapter.from_config(_act_config(), "rknn")
    incomplete = {
        "observation.state": torch.ones(3),
        "observation.images.side": torch.ones(3, 4, 5),
    }

    with pytest.raises(KeyError, match="observation.images.gripper"):
        adapter.prepare_inputs(incomplete)


def test_act_decodes_om_action_chunk():
    """A flat OM output of 12 floats decodes into an action chunk of shape (2, 6)."""
    adapter = ACTCompiledAdapter.from_config(_act_config(), "ascend_om")
    flat_output = np.arange(12, dtype=np.float32)
    cpu = torch.device("cpu")

    decoded = adapter.decode_outputs([flat_output], cpu)

    assert decoded.shape == (2, 6)


def test_act_decodes_sd3403_crop_and_updates_chunk_size():
    """SD3403 output of 16 floats decodes to (2, 6), and the chunk size becomes 2.

    The config declares ``chunk_size=1``. The test expects the adapter to
    report the chunk size of the decoded output afterwards.
    """
    config = _act_config(chunk_size=1)
    adapter = ACTCompiledAdapter.from_config(config, "ascend_om_3403")
    raw = np.arange(16, dtype=np.float32)

    decoded = adapter.decode_outputs([raw], torch.device("cpu"))

    assert decoded.shape == (2, 6)
    assert adapter.get_chunk_size() == 2


def test_pi05_adapter_prepares_runtime_inputs_and_slices_padding():
    """PI05 batches become ``PI05RuntimeInputs``; decoded actions keep 6 dims."""
    adapter = PI05CompiledAdapter.from_config(_pi05_config(), "ascend_om")
    batch = {
        "observation.images.front": torch.full((1, 3, 224, 224), 1.0),
        "observation.language.tokens": torch.arange(48).reshape(1, 48),
        "observation.language.attention_mask": torch.ones(1, 48, dtype=torch.bool),
        "_noise": torch.zeros(1, 50, 32),
    }

    runtime_inputs = adapter.prepare_inputs(batch)

    assert isinstance(runtime_inputs, PI05RuntimeInputs)
    assert runtime_inputs.images[0].shape == (1, 3, 224, 224)
    assert runtime_inputs.tokens.dtype == np.int64
    assert runtime_inputs.masks.dtype == np.bool_
    assert runtime_inputs.noise.shape == (1, 50, 32)

    # The fake output has max_action_dim (32) columns. The test expects only
    # the 6 action dims declared in output_features to remain.
    padded_actions = torch.zeros(1, 50, 32)
    action = adapter.decode_outputs(padded_actions, torch.device("cpu"))

    assert action.shape == (50, 6)
    assert adapter.get_chunk_size() == 50


def test_compiled_wrapper_reports_metadata_and_runtime_device(tmp_path):
    """The wrapper passes path/config/device to the runtime and exposes metadata."""
    state_only = _act_config(input_features={"observation.state": {"shape": [3]}})
    _write_policy(tmp_path, state_only)
    fake = FakeRuntimeSession(output=np.arange(12, dtype=np.float32))
    wrapper = CompiledPolicyWrapper("rknn", runtime_session=fake)
    cpu = torch.device("cpu")

    wrapper.load(str(tmp_path), cpu)
    action = wrapper.infer({"observation.state": torch.ones(3)})

    loaded_path, loaded_config, loaded_device = fake.loaded
    assert loaded_path == str(tmp_path)
    assert loaded_config["type"] == "act"
    assert loaded_device == cpu
    assert fake.inputs[0].shape == (1, 3)
    assert action.shape == (2, 6)
    assert (wrapper.policy_type, wrapper.backend_type) == ("act", "rknn")
    assert wrapper.uses_action_chunking is True


def test_compiled_wrapper_requires_config_json(tmp_path):
    """Loading from a directory without ``config.json`` raises ``FileNotFoundError``."""
    session = FakeRuntimeSession()
    wrapper = CompiledPolicyWrapper("rknn", runtime_session=session)
    empty_dir = str(tmp_path)

    with pytest.raises(FileNotFoundError, match="config.json"):
        wrapper.load(empty_dir, torch.device("cpu"))


def test_runtime_dependencies_import_lazily():
    """Constructing runtime sessions must not import vendor runtime modules.

    The check diffs ``sys.modules`` before and after construction. If a vendor
    module is already loaded (e.g. by an earlier test), the diff can never
    contain it and the check would pass vacuously. That case is skipped
    instead of being reported as a pass.
    """
    heavy_modules = ("acl", "rknnlite.api")
    already_loaded = [name for name in heavy_modules if name in sys.modules]
    if already_loaded:
        pytest.skip(f"cannot verify lazy import; already loaded: {already_loaded}")

    before = set(sys.modules)

    OMRuntimeSession()
    SD3403RuntimeSession()
    PI05OMRuntimeSession()

    imported = set(sys.modules) - before
    for name in heavy_modules:
        assert name not in imported, f"{name} was imported eagerly"


def test_ascend_runtime_session_selects_pi05_from_config():
    """``ascend_om`` picks the PI05 session for PI05 configs, the OM session for ACT."""
    expectations = [
        (_act_config(), OMRuntimeSession),
        (_pi05_config(), PI05OMRuntimeSession),
    ]
    for config, session_cls in expectations:
        session = create_runtime_session("ascend_om", config)
        assert isinstance(session, session_cls)


def test_manifest_resolves_single_om_policy_role(tmp_path):
    """A single-role manifest resolves the ``policy`` artifact to its resolved path."""
    artifact_dir = tmp_path / "om"
    artifact_dir.mkdir()
    model = artifact_dir / "act.om"
    model.write_bytes(b"om")
    _write_policy(tmp_path, _act_config())
    _write_manifest(
        tmp_path,
        dict(
            schema_version=1,
            policy_type="act",
            backend="ascend_om",
            artifact_dir="om",
            artifacts={"policy": "act.om"},
            execution=["policy"],
        ),
    )
    root = str(tmp_path)

    manifest = load_compiled_manifest(root, "ascend_om", "act")
    resolved = resolve_om_model_path(root, _act_config(), manifest)

    assert resolved == model.resolve()


def test_manifest_resolves_pi05_roles_and_execution(tmp_path):
    """A PI05 manifest resolves to the (vlm, action_expert) artifact paths."""
    artifact_dir = tmp_path / "om"
    artifact_dir.mkdir()
    vlm = artifact_dir / "vlm.om"
    action_expert = artifact_dir / "action_expert.om"
    vlm.write_bytes(b"vlm")
    action_expert.write_bytes(b"ae")
    _write_policy(tmp_path, _pi05_config())
    _write_manifest(
        tmp_path,
        dict(
            schema_version=1,
            policy_type="pi05",
            backend="ascend_om",
            artifact_dir="om",
            artifacts={
                "vlm": "vlm.om",
                "action_expert": "action_expert.om",
            },
            execution=["vlm", "action_expert"],
        ),
    )
    root = str(tmp_path)

    manifest = load_compiled_manifest(root, "ascend_om", "pi05")
    resolved = resolve_pi05_om_paths(root, _pi05_config(), manifest)

    assert resolved == (vlm.resolve(), action_expert.resolve())


def test_manifest_rejects_wrong_backend_policy_and_execution(tmp_path):
    """Manifests are validated against backend, policy type and execution order."""
    (tmp_path / "model.om").write_bytes(b"om")
    root = str(tmp_path)

    def build(policy_type, backend, artifacts, execution=None):
        # Key order matches the hand-written manifests so the JSON is identical.
        body = {
            "schema_version": 1,
            "policy_type": policy_type,
            "backend": backend,
            "artifacts": artifacts,
        }
        if execution is not None:
            body["execution"] = execution
        return body

    # The manifest's backend differs from the requested backend.
    _write_manifest(tmp_path, build("pi05", "rknn", {"policy": "model.om"}))
    with pytest.raises(ValueError, match="does not match requested backend"):
        load_compiled_manifest(root, "ascend_om", "pi05")

    # The manifest's policy type differs from the requested config type.
    _write_manifest(tmp_path, build("act", "ascend_om", {"policy": "model.om"}))
    with pytest.raises(ValueError, match="does not match config type"):
        load_compiled_manifest(root, "ascend_om", "pi05")

    # The manifest loads, but PI05 path resolution rejects the reversed order.
    _write_manifest(
        tmp_path,
        build(
            "pi05",
            "ascend_om",
            {"vlm": "model.om", "action_expert": "model.om"},
            execution=["action_expert", "vlm"],
        ),
    )
    loaded = load_compiled_manifest(root, "ascend_om", "pi05")
    with pytest.raises(ValueError, match="execution must be"):
        resolve_pi05_om_paths(root, _pi05_config(), loaded)


def test_manifest_required_for_om_resolution(tmp_path):
    """Without ``config.om.json``, OM path resolution raises ``FileNotFoundError``."""
    config = _act_config()
    _write_policy(tmp_path, config)
    root = str(tmp_path)

    with pytest.raises(FileNotFoundError, match="config.om.json"):
        resolve_om_model_path(root, _act_config())


def test_pi05_runtime_builds_prefix_mask_and_forwards():
    """The PI05 session sizes the prefix mask from the model's ``prefix_seq_len``
    and forwards the noise array unchanged."""

    class RecordingModel:
        prefix_seq_len = 52

        def __init__(self):
            self.forward_args = None

        def forward(self, images, tokens, masks, prefix_mask, noise=None):
            self.forward_args = (images, tokens, masks, prefix_mask, noise)
            return torch.zeros(1, 50, 32)

    session = PI05OMRuntimeSession()
    session._model = RecordingModel()
    inputs = PI05RuntimeInputs(
        images=[np.ones((1, 3, 224, 224), dtype=np.float32)],
        tokens=np.ones((1, 48), dtype=np.int64),
        masks=np.ones((1, 48), dtype=np.bool_),
        noise=np.zeros((1, 50, 32), dtype=np.float32),
    )

    output = session.execute(inputs)

    recorded = session._model.forward_args
    prefix_mask, noise = recorded[3:]
    assert output.shape == (1, 50, 32)
    assert prefix_mask.shape == (1, 1, 52, 52)
    assert noise is inputs.noise


def test_sd3403_runtime_uses_worker_public_array_api():
    """``SD3403RuntimeSession.execute`` hands the input list to ``execute_arrays``."""

    class RecordingWorker:
        def __init__(self):
            self.inputs = None
            self.closed = False

        def execute_arrays(self, inputs):
            self.inputs = inputs
            return np.arange(16, dtype=np.float32)

        def close(self):
            self.closed = True

    session = SD3403RuntimeSession()
    session._worker = RecordingWorker()
    arrays = [np.ones((1, 3), dtype=np.float32)]

    outputs = session.execute(arrays)

    assert session._worker.inputs is arrays
    assert outputs[0].shape == (16,)