"""
HandwriteLoader和HandwriteSampler的单元测试 - 精简版
测试新的目录结构
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
from akg_agents.op.utils.handwrite_loader import HandwriteLoader, HandwriteSampler
@pytest.fixture
def mock_loader():
    """Return a ``Mock`` specced on ``HandwriteLoader`` that exposes 10 pairs.

    Defined at module level so several test classes can share it.
    ``get_selected_pairs()`` returns ten pair dicts named
    ``static_shape/reduction/opt_00`` .. ``opt_09``, and
    ``read_pair_content(pair)`` returns synthetic content derived from the
    pair's name, so no file is ever read.
    """

    def _fake_read(pair):
        # Derive every content field from the pair name so each sample
        # can be traced back to the pair it came from.
        label = pair['name']
        return {
            'name': label,
            'framework_code': f"framework_code {label}",
            'impl_code': f"impl_code {label}",
            'improvement_doc': f"improve {label}",
        }

    pairs = []
    for idx in range(10):
        stem = f'opt_{idx:02d}'
        pairs.append({
            'name': f'static_shape/reduction/{stem}',
            'file_stem': stem,
            'shape_type': 'static_shape',
            'category': 'reduction',
            'torch_file': Path(f'/tmp/{stem}.py'),
            'triton_file': Path(f'/tmp/{stem}.py'),
            'improvement_file': Path(f'/tmp/{stem}.md'),
        })

    fake_loader = Mock(spec=HandwriteLoader)
    fake_loader.get_selected_pairs.return_value = pairs
    fake_loader.read_pair_content.side_effect = _fake_read
    return fake_loader
@pytest.fixture
def temp_handwrite_dir():
    """Create a throw-away handwrite directory tree and yield its layout.

    Layout created under a fresh temporary directory ``<top>``::

        <top>/python/akg_agents/                      (yielded as 'temp_dir', str)
        <top>/benchmark/akg_kernels_bench/<shape>/<category>/<name>.py
        <top>/benchmark/akg_kernels_bench/triton_ascend/impl/<shape>/<category>/<name>.py
        <top>/benchmark/akg_kernels_bench/triton_ascend/docs/<shape>/<category>/<name>.md

    Yields:
        dict: ``temp_dir`` (str), ``benchmark_root``,
        ``akg_kernels_bench_root``, ``torch_base``, ``triton_impl_base``,
        ``triton_docs_base`` (all ``Path``) and ``test_files`` (the list of
        ``(shape_type, category, name, desc)`` tuples that were written).

    The whole tree is removed afterwards. Cleanup sits in a ``finally`` so
    the directory is also removed when building the tree fails before the
    ``yield`` (pytest does not run post-yield teardown in that case).
    """
    top_dir = tempfile.mkdtemp()
    try:
        project_dir = Path(top_dir) / "python" / "akg_agents"
        project_dir.mkdir(parents=True, exist_ok=True)

        # Two levels above <top>/python/akg_agents is <top> itself; the
        # benchmark tree lives next to the "python" directory.
        akg_agents_root = project_dir.parent.parent
        benchmark_root = akg_agents_root / "benchmark"
        akg_kernels_bench_root = benchmark_root / "akg_kernels_bench"
        torch_base = akg_kernels_bench_root
        triton_impl_base = akg_kernels_bench_root / "triton_ascend" / "impl"
        triton_docs_base = akg_kernels_bench_root / "triton_ascend" / "docs"

        # (shape_type, category, file stem, human-readable description)
        test_files = [
            ("dynamic_shape", "reduction", "softmax_001", "Softmax"),
            ("dynamic_shape", "reduction", "layernorm_001", "LayerNorm"),
            ("static_shape", "reduction", "relu_001", "ReLU"),
            ("static_shape", "sorting", "topk_001", "TopK"),
            ("static_shape", "reduction", "gelu_001", "GELU"),
        ]

        for shape_type, category, name, desc in test_files:
            # Framework (torch) reference implementation.
            torch_dir = torch_base / shape_type / category
            torch_dir.mkdir(parents=True, exist_ok=True)
            (torch_dir / f"{name}.py").write_text(
                f"# Torch {desc}\ndef {name}_torch(): pass"
            )

            # Hand-written triton implementation.
            triton_impl_dir = triton_impl_base / shape_type / category
            triton_impl_dir.mkdir(parents=True, exist_ok=True)
            (triton_impl_dir / f"{name}.py").write_text(
                f"# Triton {desc}\n@triton.jit\ndef {name}_kernel(): pass"
            )

            # Accompanying optimization notes.
            triton_docs_dir = triton_docs_base / shape_type / category
            triton_docs_dir.mkdir(parents=True, exist_ok=True)
            (triton_docs_dir / f"{name}.md").write_text(
                f"# {desc} Optimization\nSuggestions for {name}"
            )

        yield {
            'temp_dir': str(project_dir),
            'benchmark_root': benchmark_root,
            'akg_kernels_bench_root': akg_kernels_bench_root,
            'torch_base': torch_base,
            'triton_impl_base': triton_impl_base,
            'triton_docs_base': triton_docs_base,
            'test_files': test_files,
        }
    finally:
        shutil.rmtree(top_dir, ignore_errors=True)
class TestHandwriteLoaderCore:
    """Core ``HandwriteLoader`` behaviour: loading, reading and selection."""

    def test_load_and_read(self, temp_handwrite_dir):
        """Test 1: load every pair from the temp tree and read one of them."""
        with patch('akg_agents.op.utils.handwrite_loader.get_project_root') as mock_root:
            # Point the loader at the temporary project directory.
            mock_root.return_value = str(temp_handwrite_dir['temp_dir'])

            loader = HandwriteLoader(dsl="triton_ascend")
            loader._init_filesystem_mode()
            loader._load_data_pairs()

            # The fixture writes five torch/triton/doc triples.
            assert len(loader._all_data_pairs) == 5
            assert len(loader._selected_data_pairs) == 5

            sample_pair = loader._all_data_pairs[0]
            for required_key in ('name', 'framework_path'):
                assert required_key in sample_pair
            # The pair name is expected to contain exactly two separators.
            assert sample_pair['name'].count('/') == 2

            payload = loader.read_pair_content(sample_pair)
            assert isinstance(payload, dict)
            for field in ('name', 'framework_code', 'impl_code', 'improvement_doc'):
                assert field in payload

            # Each field must carry the marker text written by the fixture.
            assert 'Triton' in payload['impl_code']
            assert 'Torch' in payload['framework_code']
            assert 'Optimization' in payload['improvement_doc']

    @pytest.mark.asyncio
    async def test_select_with_mock_llm(self, temp_handwrite_dir):
        """Test 2: LLM-based selection, with the selector mocked out."""
        chosen = [
            'static_shape/reduction/relu_001',
            'static_shape/reduction/gelu_001',
        ]
        with patch('akg_agents.op.utils.handwrite_loader.get_project_root') as mock_root:
            mock_root.return_value = str(temp_handwrite_dir['temp_dir'])

            loader = HandwriteLoader(
                dsl="triton_ascend",
                op_name="relu_op",
                task_desc="ReLU activation",
                config={'agent_model_config': {'default': {}}},
            )

            # Selector is patched only after the loader exists, so that
            # construction itself runs against the real class.
            with patch('akg_agents.op.utils.handwrite_loader.Selector') as MockSelector:
                selector_instance = MockSelector.return_value
                selector_instance.run = AsyncMock(return_value=list(chosen))

                await loader.select_relevant_pairs()

                picked = loader._selected_data_pairs
                assert len(picked) == 2
                picked_names = {pair['name'] for pair in picked}
                for expected_name in chosen:
                    assert expected_name in picked_names
class TestHandwriteSamplerCore:
    """Core ``HandwriteSampler`` behaviour (uses the module-level ``mock_loader``)."""

    def test_basic_sampling(self, mock_loader):
        """Test 3: consecutive samples have the requested size and do not overlap."""
        sampler = HandwriteSampler(loader=mock_loader, sample_num=3)

        first_batch = sampler.sample()
        assert len(first_batch) == 3
        second_batch = sampler.sample()
        assert len(second_batch) == 3

        first_names = {item['name'] for item in first_batch}
        second_names = {item['name'] for item in second_batch}
        assert first_names.isdisjoint(second_names)

    def test_reset_when_exhausted(self, mock_loader):
        """Test 4: the sampler resets itself once the pool is used up."""
        sampler = HandwriteSampler(loader=mock_loader, sample_num=3)

        # 3 draws x 3 items consume 9 of the 10 available documents.
        sampler.sample()
        sampler.sample()
        sampler.sample()

        # Only one unused document is left for the fourth draw ...
        leftover = sampler.sample()
        assert len(leftover) == 1

        # ... and the fifth draw is served at full size again.
        refreshed = sampler.sample()
        assert len(refreshed) == 3

    def test_independent_samplers(self, mock_loader):
        """Test 5: samplers sharing one loader keep independent state."""
        samplers = []
        for _ in range(3):
            samplers.append(HandwriteSampler(loader=mock_loader, sample_num=2))

        batches = [sampler.sample() for sampler in samplers]

        for batch in batches:
            assert len(batch) == 2
        for sampler in samplers:
            assert len(sampler._used_indices) == 2
class TestWeightedSampling:
    """Tests for the weighted (decay-rate based) sampling of ``HandwriteSampler``."""

    @pytest.fixture
    def mock_loader_large(self):
        """Return a ``Mock`` specced on ``HandwriteLoader`` that exposes 30 pairs.

        Pairs are named ``static_shape/reduction/opt_00`` .. ``opt_29``;
        ``read_pair_content`` returns synthetic content derived from the
        pair name.
        """
        loader = Mock(spec=HandwriteLoader)
        mock_pairs = [
            {
                'name': f'static_shape/reduction/opt_{i:02d}',
                'file_stem': f'opt_{i:02d}',
                'shape_type': 'static_shape',
                'category': 'reduction',
                'torch_file': Path(f'/tmp/opt_{i:02d}.py'),
                'triton_file': Path(f'/tmp/opt_{i:02d}.py'),
                'improvement_file': Path(f'/tmp/opt_{i:02d}.md'),
            }
            for i in range(30)
        ]
        loader.get_selected_pairs.return_value = mock_pairs
        loader.read_pair_content.side_effect = lambda p: {
            'name': p['name'],
            'framework_code': f"framework_code {p['name']}",
            'impl_code': f"impl_code {p['name']}",
            'improvement_doc': f"improve {p['name']}",
        }
        return loader

    def test_weight_initialization(self, mock_loader_large):
        """Test 6: weights exist, decrease with position and sum to 1."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=2.0)

        assert hasattr(sampler, '_weights')
        assert len(sampler._weights) == 30
        # Earlier documents must carry larger weights than later ones.
        assert sampler._weights[0] > sampler._weights[1] > sampler._weights[-1]
        # Weights form a normalized distribution.
        assert abs(sampler._weights.sum() - 1.0) < 1e-6

    def test_different_decay_rates(self, mock_loader_large):
        """Test 7: a larger decay rate widens the first-to-last weight ratio."""
        sampler_low = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=1.0)
        sampler_mid = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=2.0)
        sampler_high = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=5.0)

        ratio_low = sampler_low._weights[0] / sampler_low._weights[-1]
        ratio_mid = sampler_mid._weights[0] / sampler_mid._weights[-1]
        ratio_high = sampler_high._weights[0] / sampler_high._weights[-1]

        assert ratio_low < ratio_mid < ratio_high

    def test_weighted_sampling_preserves_no_repeat(self, mock_loader_large):
        """Test 8: weighted sampling never repeats a document within one cycle."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=3, decay_rate=2.0)

        all_sampled = []
        # 10 draws x 3 items cover the 30 documents exactly once.
        for _ in range(10):
            samples = sampler.sample()
            names = [s['name'] for s in samples]
            assert len(names) == len(set(names))
            all_sampled.extend(names)

        assert len(all_sampled) == 30
        assert len(set(all_sampled)) == 30

    def test_weighted_sampling_preserves_reset(self, mock_loader_large):
        """Test 9: weighted sampling still resets once the pool is exhausted."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=5, decay_rate=2.0)

        # 6 draws x 5 items consume all 30 documents.
        for _ in range(6):
            samples = sampler.sample()
            assert len(samples) == 5

        # The next draw must be served at full size from a fresh cycle.
        samples = sampler.sample()
        assert len(samples) == 5
        assert len(sampler._used_indices) == 5

    def test_weighted_sampling_bias(self, mock_loader_large):
        """Test 10: weighted sampling favours documents at the front.

        Statistical test: over 100 fresh draws the first five documents
        must be picked more than 1.5x as often as the last five.
        """
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=2.0)

        top_5_count = 0
        bottom_5_count = 0
        for _ in range(100):
            # Reset so every draw starts from the full pool.
            sampler.reset()
            for sample in sampler.sample():
                # Names end in "opt_NN"; NN is the document's position.
                idx = int(sample['name'].split('_')[-1])
                if idx < 5:
                    top_5_count += 1
                elif idx >= 25:
                    bottom_5_count += 1

        # For non-negative counts this also implies
        # top_5_count > bottom_5_count.
        assert top_5_count > bottom_5_count * 1.5

    def test_manual_reset(self, mock_loader_large):
        """Test 11: ``reset()`` clears the used-index bookkeeping."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=5, decay_rate=2.0)

        sampler.sample()
        assert len(sampler._used_indices) == 5

        sampler.reset()
        assert len(sampler._used_indices) == 0

        samples = sampler.sample()
        assert len(samples) == 5
class TestMultiRoundScenario:
    """Multi-round sampling scenario (uses the module-level ``mock_loader``)."""

    def test_multi_island_multi_round(self, mock_loader):
        """Test 12: several islands sampling over several rounds (integration)."""
        num_islands = 3
        num_rounds = 2
        sample_num = 2

        # One independent sampler per island, all backed by the same loader.
        island_samplers = []
        for _ in range(num_islands):
            island_samplers.append(
                HandwriteSampler(loader=mock_loader, sample_num=sample_num)
            )

        for _ in range(num_rounds):
            for island_sampler in island_samplers:
                suggestions = island_sampler.sample()
                assert 1 <= len(suggestions) <= sample_num

        # 2 rounds x 2 items: every island has consumed four documents.
        for island_sampler in island_samplers:
            assert len(island_sampler._used_indices) == 4