# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# openFuyao is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#          http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

"""
This module contains unit tests for NPU detector rules.
"""

from unittest import TestCase
from unittest.mock import patch

from hardware_diagnosis.event.base import FaultType
from hardware_diagnosis.detector.npu.rules import (
    HBMErrorRule, DDRErrorRule, NPUErrorRules, NetworkErrorRule
)
from hardware_diagnosis.event.hardware_event import NPUHardwareEvent


class TestErrorRule(TestCase):
    """
    TestErrorRule verifies the NPU detector rules: HBMErrorRule, DDRErrorRule,
    NPUErrorRules and NetworkErrorRule. Each test feeds a collector-shaped payload
    to one rule and checks that the emitted events carry the expected error codes
    and details. This covers faults derived from metric fields (e.g. HBM
    temperature, chip error codes) as well as entries listed under ``errors``.
    """

    def _assert_events(self, out, *expected):
        """
        Assert that ``out`` holds at least ``len(expected)`` events and that its
        leading events match ``expected`` pairwise on ``error_code`` and ``details``.

        The length is checked first so that a missing event is reported as a
        readable assertion failure rather than an IndexError.
        """
        self.assertGreaterEqual(len(out), len(expected),
                                f"expected at least {len(expected)} events, got {len(out)}")
        for index, (actual, want) in enumerate(zip(out, expected)):
            self.assertEqual(actual.error_code, want.error_code, f"event {index}: error_code mismatch")
            self.assertEqual(actual.details, want.details, f"event {index}: details mismatch")

    def test_hbmerror_rule(self):
        """
        Tests the HBMErrorRule's detection of HBM errors, including over-temperature and
        error message events. The test ensures that the error code and details are
        correctly set in the resulting events, in that order.
        """
        data = {"collector.npu.hbm": {"0": {"card_id": 0, "device_id": 0,
                                            "node_name": "node",
                                            "npu_chip_info_hbm_total_memory": 100,
                                            "npu_chip_info_hbm_frequency": 200,
                                            "npu_chip_info_hbm_used_memory": 10,
                                            # expected to trip the over-temperature check (see temp_want)
                                            "npu_chip_info_hbm_temperature": 100,
                                            "npu_chip_info_hbm_bandwidth_utilization": 2,
                                            "npu_chip_info_hbm_utilization": 2,
                                            "errors": [{"error_code": 8008, "error_message": "mock error"}]}}}
        out = HBMErrorRule().detect(data)
        temp_want = NPUHardwareEvent(error_code=hex(100), logic_id=0, card_id=0, device_id=0, node_name="node",
                                     fault_type=FaultType.HARDWARE, hardware_type="NPU_HBM", severity=1,
                                     event_id="1a4d4730-ebdb-4d29-b097-05430a08f82a",
                                     details="HBM temperature is over threshold")
        func_want = NPUHardwareEvent(error_code="8008", logic_id=0, card_id=0, device_id=0, node_name="node",
                                     fault_type=FaultType.HARDWARE, hardware_type="NPU_HBM", severity=1,
                                     event_id="9fb8f9e3-7167-42fe-a10e-0ccbde86c14a", details="mock error")
        self._assert_events(out, temp_want, func_want)

    def test_ddrerror_rule(self):
        """
        Tests the DDRErrorRule's detection of DDR errors, specifically checking if
        error codes and messages from DDR errors are correctly processed and reported.
        """
        data = {"collector.npu.ddr": {
            "0": {"card_id": 0, "device_id": 0, "node_name": "master",
                  "npu_chip_info_total_memory": 300, "npu_chip_info_used_memory": 30,
                  "npu_chip_info_memory_frequency": 39, "npu_chip_info_memory_utilization": 40,
                  "errors": [{"error_code": 1111, "error_message": "mock DDR error"}]}}}
        out = DDRErrorRule().detect(data)
        want = NPUHardwareEvent(error_code="1111", details="mock DDR error", card_id=0, device_id=0, logic_id=0,
                                node_name="master", fault_type=FaultType.HARDWARE, event_id="", hardware_type="NPU_DDR")
        self._assert_events(out, want)

    @patch("hardware_diagnosis.detector.npu.rules.get_device_error_code_string")
    def test_npuerror_rules(self, mock_info):
        """
        Tests the NPUErrorRules' detection of NPU-related errors. This test ensures that errors
        such as system memory issues are detected and reported with the appropriate error codes
        and details.
        """
        data = {"collector.npu.npu": {
            "0": {"card_id": 0, "device_id": 0, "node_name": "master",
                  "npu_chip_info_utilization": 30, "npu_chip_info_temperature": 50,
                  "npu_chip_info_health_status": 0, "npu_chip_info_error_code": [0x8C2FA001],
                  "npu_chip_info_network_status": 0,
                  "errors": [{"error_code": 8001, "error_message": "mock NPU error"}]}}}
        # The error-code lookup is patched so the chip error code resolves to a fixed
        # description; chip_want expects the first tuple element as the event details.
        mock_info.return_value = ("system out of memory", None, None)
        out = NPUErrorRules().detect(data)
        chip_want = NPUHardwareEvent(hardware_type="NPU", error_code=hex(0x8C2FA001), device_id=0,
                                     card_id=0, logic_id=0, node_name="master",
                                     fault_type=FaultType.HARDWARE,
                                     details="system out of memory")
        func_want = NPUHardwareEvent(hardware_type="NPU", error_code="8001", device_id=0,
                                     card_id=0, logic_id=0, node_name="master",
                                     fault_type=FaultType.HARDWARE,
                                     details="mock NPU error")
        self._assert_events(out, chip_want, func_want)
        # Exactly one chip error code is present, so exactly one lookup is expected.
        mock_info.assert_called_once()

    def test_network_error_rule(self):
        """
        Tests the NetworkErrorRule's detection of network-related errors, specifically
        verifying that network errors are correctly processed and reported with appropriate
        error codes and details.
        """
        data = {"collector.npu.network": {"0": {"card_id": 0,
                                                "device_id": 0, "phy_id": 3,
                                                "node_name": "master",
                                                "npu_chip_info_bandwidth_tx": 300,
                                                "npu_chip_info_bandwidth_rx": 300,
                                                "npu_chip_info_link_status": 0,
                                                "errors": [{"error_code": 8001,
                                                            "error_message": "mock network error"}]}}}

        out = NetworkErrorRule().detect(data)
        want = NPUHardwareEvent(hardware_type="NPU_Network", error_code="8001",
                                device_id=0, logic_id=0, card_id=0,
                                node_name="master",
                                fault_type=FaultType.HARDWARE, details="mock network error")
        self._assert_events(out, want)