#!/usr/bin/env python
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
# MindIE is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#          http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
import unittest
import os


from mindiesd.cache_agent import CacheAgent, CacheConfig


@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "NPU", "Skip CPU-compatible tests when MINDIE_TEST_MODE is NPU."
)
class TestAttentionCache(unittest.TestCase):
    def test_cache_func(self):
        result = [
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
            9,
            5,
            6,
            7,
            8,
            9,
            5,
            6,
            7,
            8,
            9,
            20,
            21,
            22,
            23,
            24,
            20,
            21,
            22,
            23,
            24,
            30,
            31,
            32,
            33,
            34,
        ]
        steps_count = 7
        blocks_count = 5

        config = CacheConfig(
            method="attention_cache",
            blocks_count=blocks_count,
            steps_count=steps_count,
            step_start=1,
            step_end=5,
            step_interval=3,
        )
        agent = CacheAgent(config)

        def test_cache_func(i):
            return i

        for _ in range(5):  # 多次运行测试
            cache_result = []
            for step in range(steps_count):
                for block in range(blocks_count):
                    res = agent.apply(test_cache_func, step * blocks_count + block)
                    cache_result.append(res)
            self.assertEqual(cache_result, result)

    def test_cache_func_two_result(self):
        result = [
            (0, 0),
            (0, 1),
            (0, 2),
            (0, 3),
            (0, 4),
            (1, 0),
            (1, 1),
            (1, 2),
            (1, 3),
            (1, 4),
            (1, 0),
            (1, 1),
            (1, 2),
            (1, 3),
            (1, 4),
            (1, 0),
            (1, 1),
            (1, 2),
            (1, 3),
            (1, 4),
            (4, 0),
            (4, 1),
            (4, 2),
            (4, 3),
            (4, 4),
            (4, 0),
            (4, 1),
            (4, 2),
            (4, 3),
            (4, 4),
            (6, 0),
            (6, 1),
            (6, 2),
            (6, 3),
            (6, 4),
        ]
        steps_count = 7
        blocks_count = 5

        config = CacheConfig(
            method="attention_cache",
            blocks_count=blocks_count,
            steps_count=steps_count,
            step_start=1,
            step_end=5,
            step_interval=3,
        )
        agent = CacheAgent(config)

        def test_cache_func(i, j):
            return i, j

        for _ in range(5):  # 多次运行测试
            cache_result = []
            for step in range(steps_count):
                for block in range(blocks_count):
                    res = agent.apply(test_cache_func, step, block)
                    cache_result.append(res)
            self.assertEqual(cache_result, result)


if __name__ == '__main__':
    unittest.main()