"""
Tests for tools/convert/patch_merge/modules/coverage.py.

Exercises get_printing_str, get_debug_print_node and check_log, which format
patch-call debug lines and compute patch coverage from a run's log file.
"""
import os
import pytest
import sys
import json
import tempfile
import shutil
from pathlib import Path
from unittest.mock import patch
import libcst as cst
from tools.convert.patch_merge.modules.coverage import (
get_printing_str,
get_debug_print_node,
check_log,
)
import json
import libcst as cst
import os
import pytest
import shutil
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch
from tools.convert.patch_merge.modules.coverage import (
get_printing_str,
get_debug_print_node,
check_log,
)
# Put the repository root and the coverage module's own directory on sys.path.
# NOTE(review): these inserts execute after the module-level
# `tools.convert.patch_merge.modules.coverage` imports above, so they cannot be
# what makes those imports resolve; they only affect imports performed later.
current_dir = Path(__file__).parent
project_root = current_dir.parent.parent
sys.path.insert(0, str(project_root))
coverage_path = str(project_root.joinpath("tools", "convert", "patch_merge", "modules"))
if coverage_path not in sys.path:
    sys.path.insert(0, coverage_path)
@pytest.fixture
def mock_coverage_environment():
    """Mock coverage environment with JSON patches and log file.

    Writes two files into a fresh temporary directory:

    * ``test_patches.json`` maps two origin imports to one patch each:
      ``function`` (``condition: False``) and ``another_function``
      (``condition: True``).
    * ``test.log`` holds the patch-call line for ``function`` only, followed
      by an ``=== In original call`` line. ``another_function`` is therefore
      the only patch that the log does not mention.

    Yields:
        tuple[str, str, str]: ``(patch_json_path, log_file_path, temp_dir)``.
        The temporary directory is removed on teardown.
    """
    temp_dir = tempfile.mkdtemp()
    # try/finally: the directory is removed at teardown, and also when writing
    # the fixture files fails before the yield is reached.
    try:
        patch_json = Path(temp_dir) / "test_patches.json"
        test_patches = {
            "megatron.test.module.function": [
                {
                    "patch_import": "mindspeed.test.module.function",
                    "patch_name": "function",
                    "condition": False,
                }
            ],
            "megatron.test.module.another_function": [
                {
                    "patch_import": "mindspeed.test.module.another_function",
                    "patch_name": "another_function",
                    "condition": True,
                }
            ],
        }
        with open(patch_json, "w", encoding="utf-8") as f:
            json.dump(test_patches, f, indent=2)
        log_file = Path(temp_dir) / "test.log"
        log_content = (
            "=== In patch call, origin_import: megatron.test.module.function, "
            "patch_import: mindspeed.test.module.function, patch_name: function, condition: False\n"
            "=== In original call\n"
        )
        with open(log_file, "w", encoding="utf-8") as f:
            f.write(log_content)
        yield str(patch_json), str(log_file), temp_dir
    finally:
        shutil.rmtree(temp_dir)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=1)
def test_get_printing_str_basic():
    """
    Feature: get_printing_str renders a patch record as a single log line.
    Description: A patch with condition=False is passed together with its origin import.
    Expectation: The exact "=== In patch call, ..." line, including every patch field, is returned.
    """
    expected_line = (
        "=== In patch call, origin_import: megatron.test.module.function, "
        "patch_import: mindspeed.test.module.function, patch_name: function, condition: False"
    )
    rendered = get_printing_str(
        "megatron.test.module.function",
        {
            "patch_import": "mindspeed.test.module.function",
            "patch_name": "function",
            "condition": False,
        },
    )
    assert rendered == expected_line
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=2)
def test_get_printing_str_with_condition_true():
    """
    Feature: get_printing_str reflects the patch's condition flag.
    Description: A patch whose condition is True is formatted alongside its origin import.
    Expectation: The rendered line matches exactly and ends with "condition: True".
    """
    expected_line = (
        "=== In patch call, origin_import: megatron.test.module.function, "
        "patch_import: mindspeed.test.module.function, patch_name: function, condition: True"
    )
    rendered = get_printing_str(
        "megatron.test.module.function",
        {
            "patch_import": "mindspeed.test.module.function",
            "patch_name": "function",
            "condition": True,
        },
    )
    assert rendered == expected_line
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=3)
def test_get_debug_print_node_with_patch():
    """
    Feature: get_debug_print_node builds a libcst print statement for a patch call.
    Description: A patch_info dict carrying origin_import and raw_patch is converted to a CST node.
    Expectation: Rendering the node inside a cst.Module yields code that calls print and
        mentions both the origin import and the patch import.
    """
    origin = "megatron.test.module.function"
    target = "mindspeed.test.module.function"
    node = get_debug_print_node(
        {
            "origin_import": origin,
            "raw_patch": {
                "patch_import": target,
                "patch_name": "function",
                "condition": False,
            },
        }
    )
    assert node is not None
    rendered = cst.Module(body=[node]).code
    for fragment in ("print", origin, target):
        assert fragment in rendered
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=4)
def test_get_debug_print_node_without_patch():
    """
    Feature: get_debug_print_node builds a print statement for the unpatched path.
    Description: Passing None instead of patch information selects the original-call message.
    Expectation: The rendered code calls print with '=== In original call'.
    """
    node = get_debug_print_node(None)
    assert node is not None
    rendered = cst.Module(body=[node]).code
    for fragment in ("print", "=== In original call"):
        assert fragment in rendered
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=5)
def test_check_log_basic_coverage(mock_coverage_environment):
    """
    Feature: check_log reports coverage for a partially hit patch set.
    Description: Uses mock_coverage_environment, whose log mentions only one of its two patches.
    Expectation: Module/patch coverage summaries and a dump notice are printed, and
        test_patches_not_hit_cases.json lists the unhit origin import with its single patch.
    """
    patch_json, log_file, _temp_dir = mock_coverage_environment
    with patch("builtins.print") as mock_print:
        check_log(patch_json, log_file)
    assert mock_print.called
    # First positional argument of every print() call made by check_log.
    messages = [c.args[0] for c in mock_print.call_args_list]
    for needle in ("module coverage:", "patch coverage:", "Patches not hit were dumped"):
        assert any(needle in msg for msg in messages)
    not_hit_path = Path(patch_json).with_name("test_patches_not_hit_cases.json")
    assert not_hit_path.exists()
    not_hit_data = json.loads(not_hit_path.read_text(encoding="utf-8"))
    unhit_key = "megatron.test.module.another_function"
    assert unhit_key in not_hit_data
    assert len(not_hit_data[unhit_key]) == 1
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=6)
def test_check_log_missing_json_file(tmp_path):
    """
    Feature: check_log behaviour when the patch JSON file does not exist.
    Description: check_log is called with a JSON path that was never created and a
        valid log file.
    Expectation: check_log raises UnboundLocalError rather than reporting the missing
        file and returning; this test pins that current behaviour.
    """
    missing_json = tmp_path / "missing.json"
    log_file = tmp_path / "log.txt"
    log_file.write_text("dummy log", encoding="utf-8")
    # NOTE(review): this pins a crash in check_log. If check_log is changed to handle a
    # missing JSON file gracefully, replace this with an assertion on its output.
    with pytest.raises(UnboundLocalError):
        check_log(str(missing_json), str(log_file))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=7)
def test_check_log_missing_log_file(tmp_path):
    """
    Feature: check_log behaviour when the log file does not exist.
    Description: A patch JSON with one origin import and an empty patch list is paired
        with a log path that was never created.
    Expectation: check_log raises ZeroDivisionError; this test pins that current behaviour.
    """
    patch_json = tmp_path / "patches.json"
    # Use a context manager so the JSON is flushed and the handle closed before
    # check_log opens the file, rather than relying on garbage collection.
    with open(patch_json, "w", encoding="utf-8") as f:
        json.dump({"mod": []}, f)
    missing_log = tmp_path / "missing.log"
    # NOTE(review): "mod" has zero patches, so the ZeroDivisionError may come from a
    # zero-patch ratio rather than from the missing log itself; confirm in check_log.
    with pytest.raises(ZeroDivisionError):
        check_log(str(patch_json), str(missing_log))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=8)
def test_check_log_invalid_json_file(tmp_path):
    """
    Feature: check_log behaviour when the patch JSON file cannot be parsed.
    Description: The JSON file contains malformed text ("{ invalid json"); the log file
        is valid.
    Expectation: check_log raises UnboundLocalError rather than reporting the invalid
        file and returning; this test pins that current behaviour.
    """
    invalid_json = tmp_path / "invalid.json"
    invalid_json.write_text("{ invalid json", encoding="utf-8")
    log_file = tmp_path / "log.txt"
    log_file.write_text("dummy log", encoding="utf-8")
    # NOTE(review): this pins a crash in check_log. If check_log is changed to handle
    # malformed JSON gracefully, replace this with an assertion on its output.
    with pytest.raises(UnboundLocalError):
        check_log(str(invalid_json), str(log_file))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=9)
def test_check_log_empty_patches(tmp_path):
    """
    Feature: check_log behaviour when the patch JSON contains no entries.
    Description: The patch JSON is an empty object and the log has no patch lines.
    Expectation: check_log raises ZeroDivisionError (presumably from a 0/0 coverage
        ratio) instead of reporting 0/0 statistics; this test pins that current behaviour.
    """
    empty_json = tmp_path / "empty.json"
    with open(empty_json, "w", encoding="utf-8") as f:
        json.dump({}, f)
    log_file = tmp_path / "log.txt"
    log_file.write_text("dummy log", encoding="utf-8")
    with pytest.raises(ZeroDivisionError):
        check_log(str(empty_json), str(log_file))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=10)
def test_check_log_no_hits(tmp_path):
    """
    Feature: check_log handles case where no patches are hit.
    Description: The log file contains only '=== In original call' and no patch-call line.
    Expectation: Module and patch coverage are both reported as 0/1, and the single
        patch appears in patches_not_hit_cases.json.
    """
    patch_json = tmp_path / "patches.json"
    patches = {
        "megatron.mod.func": [
            {
                "patch_import": "mindspeed.mod.func",
                "patch_name": "func",
                "condition": False,
            }
        ]
    }
    with open(patch_json, "w", encoding="utf-8") as f:
        json.dump(patches, f)
    log_file = tmp_path / "log.txt"
    log_file.write_text("=== In original call\n", encoding="utf-8")
    with patch("builtins.print") as mock_print:
        check_log(str(patch_json), str(log_file))
    assert mock_print.called
    print_calls = [c.args[0] for c in mock_print.call_args_list]
    # Pin the zero hit counts promised above, using the "<hit>/<total>" format
    # that the full-coverage tests assert ("module coverage: 2/2").
    assert any("module coverage: 0/1" in msg for msg in print_calls)
    assert any("patch coverage: 0/1" in msg for msg in print_calls)
    not_hit_file = tmp_path / "patches_not_hit_cases.json"
    assert not_hit_file.exists()
    with open(not_hit_file, "r", encoding="utf-8") as f:
        not_hit = json.load(f)
    assert "megatron.mod.func" in not_hit
    assert len(not_hit["megatron.mod.func"]) == 1
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=11)
def test_check_log_full_coverage(tmp_path):
    """
    Feature: check_log handles case where all patches are hit.
    Description: The log file holds one get_printing_str line for every registered patch.
    Expectation: Module and patch coverage are both 2/2 and a ratio of 1.000 is printed.
    """
    patch_json = tmp_path / "patches.json"
    patches = {
        "megatron.mod.func1": [
            {
                "patch_import": "mindspeed.mod.func1",
                "patch_name": "func1",
                "condition": False,
            }
        ],
        "megatron.mod.func2": [
            {
                "patch_import": "mindspeed.mod.func2",
                "patch_name": "func2",
                "condition": True,
            }
        ],
    }
    with open(patch_json, "w", encoding="utf-8") as f:
        json.dump(patches, f)
    # Build the log from the module-level get_printing_str so it contains exactly
    # the lines check_log looks for.
    log_lines = [
        get_printing_str(origin, raw_patch)
        for origin, raw_patches in patches.items()
        for raw_patch in raw_patches
    ]
    log_file = tmp_path / "log.txt"
    log_file.write_text("\n".join(log_lines), encoding="utf-8")
    with patch("builtins.print") as mock_print:
        check_log(str(patch_json), str(log_file))
    assert mock_print.called
    print_calls = [c.args[0] for c in mock_print.call_args_list]
    assert any("module coverage: 2/2" in msg for msg in print_calls)
    assert any("patch coverage: 2/2" in msg for msg in print_calls)
    assert any("ratio=1.000" in msg for msg in print_calls)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.run(order=12)
def test_check_log_quoted_import_keys(tmp_path):
    """
    Feature: check_log matches quoted origin-import keys against unquoted log lines.
    Description: The patch JSON key is wrapped in single quotes ('megatron.mod.func'),
        while the log line is produced from the unquoted import path.
    Expectation: The patch still counts as hit: module and patch coverage are 1/1.
    """
    patch_json = tmp_path / "patches.json"
    quoted_key = "'megatron.mod.func'"
    patches = {
        quoted_key: [
            {
                "patch_import": "mindspeed.mod.func",
                "patch_name": "func",
                "condition": False,
            }
        ]
    }
    with open(patch_json, "w", encoding="utf-8") as f:
        json.dump(patches, f)
    # The log line uses the bare import path, as runtime debug prints would.
    log_line = get_printing_str("megatron.mod.func", patches[quoted_key][0])
    log_file = tmp_path / "log.txt"
    log_file.write_text(log_line, encoding="utf-8")
    with patch("builtins.print") as mock_print:
        check_log(str(patch_json), str(log_file))
    assert mock_print.called
    print_calls = [c.args[0] for c in mock_print.call_args_list]
    assert any("module coverage: 1/1" in msg for msg in print_calls)
    assert any("patch coverage: 1/1" in msg for msg in print_calls)