# Owner(s): ["module: tests"]
import torch_npu
from torch_npu.testing.testcase import run_tests, TestCase
from torch_npu.utils.affinity import _reset_thread_affinity, _set_thread_affinity


class TestAffinity(TestCase):
    """Tests for ``_reset_thread_affinity`` and ``_set_thread_affinity``.

    Where a test needs to observe the call into ``torch_npu._C``, the native
    attribute is swapped for a Python stub and restored in a ``finally``
    clause so the replacement never outlives the test.
    """

    def test_reset_thread_affinity(self):
        """The helper invokes ``_npu_reset_thread_affinity`` exactly once."""
        native = torch_npu._C
        saved = native._npu_reset_thread_affinity
        invocations = []

        def fake_reset():
            invocations.append(None)

        native._npu_reset_thread_affinity = fake_reset
        try:
            _reset_thread_affinity()
            self.assertEqual(len(invocations), 1)
        finally:
            # Put the real binding back even if the assertion above fails.
            native._npu_reset_thread_affinity = saved

    def test_set_thread_affinity_invalid_length(self):
        """Lists of length three and zero raise ``ValueError``."""
        for bad_range in ([1, 2, 3], []):
            with self.assertRaises(ValueError) as caught:
                _set_thread_affinity(bad_range)
            self.assertIn("Invalid core range", str(caught.exception))

    def test_set_thread_affinity_negative_values(self):
        """A negative value in either position raises ``ValueError``."""
        for bad_range in ([-1, 5], [2, -3]):
            with self.assertRaises(ValueError) as caught:
                _set_thread_affinity(bad_range)
            self.assertIn("Invalid core range", str(caught.exception))

    def test_set_thread_affinity_valid_range(self):
        """Valid ranges reach the native setter as an expanded core list."""
        native = torch_npu._C
        saved = native._npu_set_thread_affinity
        # Holds the argument of the most recent stub call; starts as an
        # empty list so a missing call shows up as a plain assertion failure.
        seen = {"cores": []}

        def fake_set(cores):
            seen["cores"] = cores

        # (input passed to the helper, core list the stub should receive)
        cases = (
            ([2, 5], [2, 3, 4, 5]),
            ([[2, 5], [7, 9]], [2, 3, 4, 5, 7, 8, 9]),
            ([[2, 7], [4, 9]], [2, 3, 4, 5, 6, 7, 8, 9]),
        )

        native._npu_set_thread_affinity = fake_set
        try:
            for core_range, expected in cases:
                _set_thread_affinity(core_range)
                self.assertEqual(seen["cores"], expected)
        finally:
            native._npu_set_thread_affinity = saved

    def test_set_thread_affinity_none(self):
        """Passing ``None`` calls the native setter once with ``(-1, -1)``."""
        native = torch_npu._C
        saved = native._npu_set_thread_affinity
        recorded = []

        def fake_set(start, end):
            recorded.append((start, end))

        native._npu_set_thread_affinity = fake_set
        try:
            _set_thread_affinity(None)
            self.assertEqual(recorded, [(-1, -1)])
        finally:
            native._npu_set_thread_affinity = saved


# Invoke run_tests() when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()