#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Tests for TemporalSmoother module.
"""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import numpy as np
import torch
import pytest

from action_dispatch.temporal_smoother import (
    TemporalSmoother,
    TemporalSmootherConfig,
    TemporalSmootherManager,
)


class TestTemporalSmootherConfig:
    def test_default_config(self):
        config = TemporalSmootherConfig()
        assert config.enabled is True
        assert config.chunk_size == 100
        assert config.temporal_ensemble_coeff == 0.01
        assert config.device is None

    def test_custom_config(self):
        config = TemporalSmootherConfig(
            enabled=False,
            chunk_size=50,
            temporal_ensemble_coeff=0.05,
            device='cuda:0',
        )
        assert config.enabled is False
        assert config.chunk_size == 50
        assert config.temporal_ensemble_coeff == 0.05
        assert config.device == 'cuda:0'

    def test_invalid_chunk_size(self):
        with pytest.raises(ValueError):
            TemporalSmootherConfig(chunk_size=0)
        with pytest.raises(ValueError):
            TemporalSmootherConfig(chunk_size=-1)


class TestTemporalSmoother:
    def test_basic_update_and_get(self):
        config = TemporalSmootherConfig(chunk_size=10)
        smoother = TemporalSmoother(config)
        
        actions = np.random.randn(10, 7)
        smoother.update(actions, 0)
        
        assert smoother.plan_length == 10
        
        for i in range(10):
            action = smoother.get_next_action()
            assert action.shape == (7,)
            assert smoother.plan_length == 9 - i

    def test_disabled_smoothing(self):
        config = TemporalSmootherConfig(enabled=False, chunk_size=10)
        smoother = TemporalSmoother(config)
        
        actions1 = np.ones((10, 7))
        smoother.update(actions1, 0)
        
        for _ in range(5):
            smoother.get_next_action()
        
        assert smoother.plan_length == 5
        
        plan_len_at_inference_start = 5
        for _ in range(3):
            smoother.get_next_action()
        actions_executed = plan_len_at_inference_start - smoother.plan_length
        
        actions2 = np.zeros((10, 7))
        smoother.update(actions2, actions_executed)
        
        assert smoother.plan_length == 7
        
        for _ in range(7):
            action = smoother.get_next_action()
            np.testing.assert_array_almost_equal(action, np.zeros(7))

    def test_cross_frame_smoothing(self):
        config = TemporalSmootherConfig(enabled=True, chunk_size=10, temporal_ensemble_coeff=0.01)
        smoother = TemporalSmoother(config)
        
        actions1 = np.ones((10, 7)) * 1.0
        smoother.update(actions1, 0)
        
        for _ in range(3):
            smoother.get_next_action()
        
        assert smoother.plan_length == 7
        
        plan_len_at_inference_start = 7
        for _ in range(2):
            smoother.get_next_action()
        actions_executed = plan_len_at_inference_start - smoother.plan_length
        
        actions2 = np.ones((10, 7)) * 2.0
        smoother.update(actions2, actions_executed)
        
        assert smoother.plan_length == 8
        
        first_action = smoother.peek_next_action()
        assert first_action is not None
        
        assert not np.allclose(first_action.numpy(), np.ones(7) * 1.0)
        assert not np.allclose(first_action.numpy(), np.ones(7) * 2.0)

    def test_tensor_input(self):
        config = TemporalSmootherConfig(chunk_size=10)
        smoother = TemporalSmoother(config)
        
        actions = torch.randn(10, 7)
        smoother.update(actions, 0)
        
        assert smoother.plan_length == 10
        
        action = smoother.get_next_action()
        assert isinstance(action, torch.Tensor)
        assert action.shape == (7,)

    def test_reset(self):
        config = TemporalSmootherConfig(chunk_size=10)
        smoother = TemporalSmoother(config)
        
        actions = np.random.randn(10, 7)
        smoother.update(actions, 0)
        
        assert smoother.plan_length == 10
        
        smoother.reset()
        
        assert smoother.plan_length == 0
        assert smoother._smoothed_actions is None

    def test_empty_input(self):
        config = TemporalSmootherConfig(chunk_size=10)
        smoother = TemporalSmoother(config)
        
        actions = np.array([]).reshape(0, 7)
        result = smoother.update(actions, 0)
        
        assert result == 0
        assert smoother.plan_length == 0

    def test_get_next_action_raises_on_empty(self):
        config = TemporalSmootherConfig(chunk_size=10)
        smoother = TemporalSmoother(config)
        
        with pytest.raises(IndexError):
            smoother.get_next_action()

    def test_1d_input_reshaped(self):
        config = TemporalSmootherConfig(chunk_size=10)
        smoother = TemporalSmoother(config)
        
        actions = np.random.randn(7)
        smoother.update(actions, 0)
        
        assert smoother.plan_length == 1
        
        action = smoother.get_next_action()
        assert action.shape == (7,)


class TestTemporalSmootherManager:
    def test_manager_basic(self):
        manager = TemporalSmootherManager(
            enabled=True,
            chunk_size=10,
            temporal_ensemble_coeff=0.01,
        )
        
        assert manager.is_enabled is True
        assert manager.plan_length == 0
        
        actions = np.random.randn(10, 7)
        manager.update(actions, 0)
        
        assert manager.plan_length == 10

    def test_manager_toggle(self):
        manager = TemporalSmootherManager(enabled=True, chunk_size=10)
        
        assert manager.is_enabled is True
        
        manager.set_enabled(False)
        assert manager.is_enabled is False
        
        manager.set_enabled(True)
        assert manager.is_enabled is True

    def test_manager_peek(self):
        manager = TemporalSmootherManager(enabled=True, chunk_size=10)
        
        actions = np.random.randn(10, 7)
        manager.update(actions, 0)
        
        peeked = manager.peek_next_action()
        assert peeked is not None
        
        assert manager.plan_length == 10
        
        gotten = manager.get_next_action()
        assert manager.plan_length == 9


class TestSmoothingFormula:
    def test_weight_calculation(self):
        config = TemporalSmootherConfig(chunk_size=5, temporal_ensemble_coeff=0.0)
        smoother = TemporalSmoother(config)
        
        expected_weights = torch.ones(5)
        np.testing.assert_array_almost_equal(
            smoother._weights.numpy(), expected_weights.numpy()
        )

    def test_positive_coeff_weights(self):
        config = TemporalSmootherConfig(chunk_size=5, temporal_ensemble_coeff=0.1)
        smoother = TemporalSmoother(config)
        
        assert smoother._weights[0] > smoother._weights[4]
        assert smoother._weights[0] == 1.0

    def test_cumulative_weights(self):
        config = TemporalSmootherConfig(chunk_size=5, temporal_ensemble_coeff=0.0)
        smoother = TemporalSmoother(config)
        
        expected_cumsum = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        np.testing.assert_array_almost_equal(
            smoother._weights_cumsum.numpy(), expected_cumsum.numpy()
        )


if __name__ == '__main__':
    pytest.main([__file__, '-v'])