"""Unit tests for StressDetectCallBack stress-detection result handling."""
import unittest
import pytest
import mindspore as ms
from mindformers.core.callback import StressDetectCallBack
from mindformers.tools.logger import get_logger
# Run these tests on the host CPU backend.
ms.set_device(device_target='CPU')

logger = get_logger()

# Result codes fed to StressDetectCallBack.log_stress_detect_result by the tests below.
PASS_CODE = 0                # tests expect an INFO "Stress detection passed" log
VOLTAGE_ERROR_CODE = 574007  # tests expect a RuntimeError ("Voltage recovery failed ...")
OTHER_ERROR_CODE = 174003    # tests expect a WARNING ("Stress detection failed ...")
class TestStressDetectCallBack(unittest.TestCase):
    """Tests for ``StressDetectCallBack.log_stress_detect_result``.

    Each test passes a single-element result list to the callback and checks
    the outcome: an INFO log for ``PASS_CODE``, a ``RuntimeError`` for
    ``VOLTAGE_ERROR_CODE``, and a WARNING log for ``OTHER_ERROR_CODE``.
    """

    def setUp(self):
        # Constructor arguments shared by every test case.
        self.detection_interval = 10
        self.num_detections = 1
        self.dataset_size = 1024

    def _create_callback(self):
        """Build a ``StressDetectCallBack`` from the arguments set in ``setUp``."""
        return StressDetectCallBack(
            detection_interval=self.detection_interval,
            num_detections=self.num_detections,
            dataset_size=self.dataset_size,
        )

    def _assert_output_contains(self, output, target):
        """Fail unless at least one captured log record contains ``target``.

        Args:
            output: the ``output`` list produced by ``assertLogs``.
            target: substring that must appear in some record.
        """
        self.assertTrue(
            any(target in message for message in output),
            msg=f"Log should contain '{target}' but was: {output}"
        )

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_log_stress_detect_result_passed(self):
        """
        Feature: StressDetectCallBack
        Description: Test StressDetectCallBack log_stress_detect_result with a passing result
        Expectation: No exception; an INFO log containing "Stress detection passed"
        """
        detect_ret_list = [PASS_CODE]
        callback = self._create_callback()
        with self.assertLogs(logger, level='INFO') as log:
            callback.log_stress_detect_result(detect_ret_list)
        self._assert_output_contains(log.output, "Stress detection passed")

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_log_stress_detect_result_voltage_error(self):
        """
        Feature: StressDetectCallBack
        Description: Test StressDetectCallBack log_stress_detect_result with the voltage error code
        Expectation: RuntimeError mentioning the voltage recovery error code
        """
        detect_ret_list = [VOLTAGE_ERROR_CODE]
        callback = self._create_callback()
        with self.assertRaises(RuntimeError) as context:
            callback.log_stress_detect_result(detect_ret_list)
        self.assertIn(f"Voltage recovery failed with error code: {VOLTAGE_ERROR_CODE}", str(context.exception))

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_log_stress_detect_result_other_error(self):
        """
        Feature: StressDetectCallBack
        Description: Test StressDetectCallBack log_stress_detect_result with a non-voltage error code
        Expectation: No exception; a WARNING log containing the error code
        """
        detect_ret_list = [OTHER_ERROR_CODE]
        callback = self._create_callback()
        with self.assertLogs(logger, level='WARNING') as log:
            callback.log_stress_detect_result(detect_ret_list)
        self._assert_output_contains(
            log.output, f"Stress detection failed with error code: {OTHER_ERROR_CODE}"
        )