# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
HandwriteLoader和HandwriteSampler的单元测试 - 精简版
测试新的目录结构
"""

import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
from akg_agents.op.utils.handwrite_loader import HandwriteLoader, HandwriteSampler


@pytest.fixture
def mock_loader():
    """Provide a spec'd ``Mock`` of ``HandwriteLoader`` backed by 10 fake pairs.

    Shared by several test classes. ``get_selected_pairs`` returns the fake
    pair dicts, and ``read_pair_content`` echoes each pair's ``name`` into
    fake code/doc strings so a test can tell which pair was read.
    """

    def _fake_pair(idx):
        # One fake pair record; every field is derived from the zero-padded index.
        stem = f'opt_{idx:02d}'
        return {
            'name': f'static_shape/reduction/{stem}',
            'file_stem': stem,
            'shape_type': 'static_shape',
            'category': 'reduction',
            'torch_file': Path(f'/tmp/{stem}.py'),
            'triton_file': Path(f'/tmp/{stem}.py'),
            'improvement_file': Path(f'/tmp/{stem}.md'),
        }

    def _fake_content(pair):
        # Stand-in for read_pair_content: no file I/O, just name-tagged strings.
        pair_name = pair['name']
        return {
            'name': pair_name,
            'framework_code': f"framework_code {pair_name}",
            'impl_code': f"impl_code {pair_name}",
            'improvement_doc': f"improve {pair_name}",
        }

    fake_loader = Mock(spec=HandwriteLoader)
    fake_loader.get_selected_pairs.return_value = [_fake_pair(idx) for idx in range(10)]
    fake_loader.read_pair_content.side_effect = _fake_content
    return fake_loader


@pytest.fixture
def temp_handwrite_dir():
    """Build a temporary handwrite data tree that mirrors the real project layout.

    ``HandwriteLoader._init_filesystem_mode`` applies ``.parent.parent`` to the
    project root. The fake project root is therefore nested two levels below a
    top-level temp dir, so that walking up still lands inside the temp dir::

        top_dir/                              <- top-level temp dir (stands in for akg_agents/)
          ├── python/akg_agents/              <- project_root (returned by get_project_root)
          └── benchmark/akg_kernels_bench/    <- handwrite data root
              ├── triton_ascend/
              │   ├── impl/
              │   └── docs/
              ├── dynamic_shape/
              └── static_shape/

    Each ``(shape_type, category, name, desc)`` entry in ``test_files`` produces
    ``<shape_type>/<category>/<name>.py`` under the torch base and under
    ``triton_ascend/impl``, plus ``<shape_type>/<category>/<name>.md`` under
    ``triton_ascend/docs``.

    Yields:
        dict with ``temp_dir`` (str, the fake project root); the ``Path`` roots
        ``benchmark_root``, ``akg_kernels_bench_root``, ``torch_base``,
        ``triton_impl_base`` and ``triton_docs_base``; and ``test_files``
        (the spec list above).
    """
    top_dir = tempfile.mkdtemp()
    # try/finally: cleanup written after a plain ``yield`` never runs if setup
    # raises before reaching the yield, which would leak the temp tree.
    try:
        temp_dir = Path(top_dir) / "python" / "akg_agents"
        temp_dir.mkdir(parents=True, exist_ok=True)
        temp_dir = str(temp_dir)

        # Two levels up from the fake project root is top_dir; benchmark/ lives there.
        akg_agents_root = Path(temp_dir).parent.parent
        benchmark_root = akg_agents_root / "benchmark"
        akg_kernels_bench_root = benchmark_root / "akg_kernels_bench"
        torch_base = akg_kernels_bench_root
        triton_impl_base = akg_kernels_bench_root / "triton_ascend" / "impl"
        triton_docs_base = akg_kernels_bench_root / "triton_ascend" / "docs"

        # Fixture files covering both dynamic_shape and static_shape.
        test_files = [
            ("dynamic_shape", "reduction", "softmax_001", "Softmax"),
            ("dynamic_shape", "reduction", "layernorm_001", "LayerNorm"),
            ("static_shape", "reduction", "relu_001", "ReLU"),
            ("static_shape", "sorting", "topk_001", "TopK"),
            ("static_shape", "reduction", "gelu_001", "GELU"),
        ]

        for shape_type, category, name, desc in test_files:
            # Torch reference file.
            torch_dir = torch_base / shape_type / category
            torch_dir.mkdir(parents=True, exist_ok=True)
            (torch_dir / f"{name}.py").write_text(
                f"# Torch {desc}\ndef {name}_torch(): pass", encoding="utf-8"
            )

            # Triton implementation file.
            triton_impl_dir = triton_impl_base / shape_type / category
            triton_impl_dir.mkdir(parents=True, exist_ok=True)
            (triton_impl_dir / f"{name}.py").write_text(
                f"# Triton {desc}\n@triton.jit\ndef {name}_kernel(): pass", encoding="utf-8"
            )

            # Optimisation-suggestion document.
            triton_docs_dir = triton_docs_base / shape_type / category
            triton_docs_dir.mkdir(parents=True, exist_ok=True)
            (triton_docs_dir / f"{name}.md").write_text(
                f"# {desc} Optimization\nSuggestions for {name}", encoding="utf-8"
            )

        yield {
            'temp_dir': temp_dir,
            'benchmark_root': benchmark_root,
            'akg_kernels_bench_root': akg_kernels_bench_root,
            'torch_base': torch_base,
            'triton_impl_base': triton_impl_base,
            'triton_docs_base': triton_docs_base,
            'test_files': test_files,
        }
    finally:
        # Remove the whole top-level temp dir (both python/ and benchmark/).
        shutil.rmtree(top_dir, ignore_errors=True)


class TestHandwriteLoaderCore:
    """Core HandwriteLoader behaviour against the on-disk ``temp_handwrite_dir`` tree."""

    def test_load_and_read(self, temp_handwrite_dir):
        """Test 1: load every file pair from disk, then read one pair's content."""
        fake_root = str(temp_handwrite_dir['temp_dir'])
        with patch('akg_agents.op.utils.handwrite_loader.get_project_root',
                   return_value=fake_root):
            loader = HandwriteLoader(dsl="triton_ascend")
            # Initialise filesystem mode and load the pairs explicitly.
            loader._init_filesystem_mode()
            loader._load_data_pairs()

            # All 5 fixture pairs are discovered, and all of them start out selected.
            assert len(loader._all_data_pairs) == 5
            assert len(loader._selected_data_pairs) == 5

            pair = loader._all_data_pairs[0]
            assert 'name' in pair
            assert 'framework_path' in pair
            # Names follow shape_type/category/file_stem, i.e. exactly two slashes.
            assert pair['name'].count('/') == 2

            content = loader.read_pair_content(pair)

            assert isinstance(content, dict)
            expected_keys = ('name', 'framework_code', 'impl_code', 'improvement_doc')
            missing = [key for key in expected_keys if key not in content]
            assert not missing
            assert 'Triton' in content['impl_code']
            assert 'Torch' in content['framework_code']
            assert 'Optimization' in content['improvement_doc']

    @pytest.mark.asyncio
    async def test_select_with_mock_llm(self, temp_handwrite_dir):
        """Test 2: LLM-driven pair selection, with the Selector mocked out."""
        fake_root = str(temp_handwrite_dir['temp_dir'])
        expected_names = (
            'static_shape/reduction/relu_001',
            'static_shape/reduction/gelu_001',
        )
        with patch('akg_agents.op.utils.handwrite_loader.get_project_root',
                   return_value=fake_root):
            loader = HandwriteLoader(
                dsl="triton_ascend",
                op_name="relu_op",
                task_desc="ReLU activation",
                config={'agent_model_config': {'default': {}}},
            )

            with patch('akg_agents.op.utils.handwrite_loader.Selector') as selector_cls:
                # The mocked LLM answers with full shape_type/category/file_stem names.
                selector_cls.return_value.run = AsyncMock(return_value=list(expected_names))

                await loader.select_relevant_pairs()

                assert len(loader._selected_data_pairs) == 2
                chosen = {pair['name'] for pair in loader._selected_data_pairs}
                assert chosen.issuperset(expected_names)


class TestHandwriteSamplerCore:
    """Core HandwriteSampler behaviour, driven by the shared ``mock_loader`` fixture."""

    def test_basic_sampling(self, mock_loader):
        """Test 3: each draw has the requested size and consecutive draws never overlap."""
        sampler = HandwriteSampler(loader=mock_loader, sample_num=3)

        first_batch = sampler.sample()
        assert len(first_batch) == 3

        second_batch = sampler.sample()
        assert len(second_batch) == 3

        # No pair may be handed out twice before the pool is exhausted.
        first_names = {item['name'] for item in first_batch}
        second_names = {item['name'] for item in second_batch}
        assert first_names.isdisjoint(second_names)

    def test_reset_when_exhausted(self, mock_loader):
        """Test 4: the pool refills automatically once every pair has been used."""
        sampler = HandwriteSampler(loader=mock_loader, sample_num=3)

        # Three full draws consume 9 of the 10 pairs.
        for _ in range(3):
            sampler.sample()

        # Only a single unused pair is left for the 4th draw.
        fourth_batch = sampler.sample()
        assert len(fourth_batch) == 1

        # The 5th draw starts from a freshly reset pool.
        fifth_batch = sampler.sample()
        assert len(fifth_batch) == 3

    def test_independent_samplers(self, mock_loader):
        """Test 5: samplers sharing one loader keep independent state."""
        samplers = [HandwriteSampler(loader=mock_loader, sample_num=2) for _ in range(3)]

        batches = [sampler.sample() for sampler in samplers]

        # Every sampler could draw a full batch...
        for batch in batches:
            assert len(batch) == 2

        # ...and each tracks only its own used indices.
        for sampler in samplers:
            assert len(sampler._used_indices) == 2


class TestWeightedSampling:
    """Tests for the decay-weighted sampling in HandwriteSampler."""

    @pytest.fixture
    def mock_loader_large(self):
        """Spec'd ``Mock`` of ``HandwriteLoader`` backed by 30 fake pairs.

        Same shape as the module-level ``mock_loader`` fixture but with a larger
        pool, so that weighting effects are measurable.
        """
        loader = Mock(spec=HandwriteLoader)

        mock_pairs = [
            {
                'name': f'static_shape/reduction/opt_{i:02d}',
                'file_stem': f'opt_{i:02d}',
                'shape_type': 'static_shape',
                'category': 'reduction',
                'torch_file': Path(f'/tmp/opt_{i:02d}.py'),
                'triton_file': Path(f'/tmp/opt_{i:02d}.py'),
                'improvement_file': Path(f'/tmp/opt_{i:02d}.md'),
            }
            for i in range(30)
        ]

        loader.get_selected_pairs.return_value = mock_pairs
        # Echo the pair name into fake content instead of reading files.
        loader.read_pair_content.side_effect = lambda p: {
            'name': p['name'],
            'framework_code': f"framework_code {p['name']}",
            'impl_code': f"impl_code {p['name']}",
            'improvement_doc': f"improve {p['name']}",
        }

        return loader

    def test_weight_initialization(self, mock_loader_large):
        """Test 6: weights are computed, one per pair, decreasing and normalised."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=2.0)

        # Weights exist and there is one per pair.
        assert hasattr(sampler, '_weights')
        assert len(sampler._weights) == 30

        # Earlier pairs are weighted more heavily than later ones.
        assert sampler._weights[0] > sampler._weights[1] > sampler._weights[-1]

        # Weights are normalised to sum to 1 (absolute tolerance only).
        assert sampler._weights.sum() == pytest.approx(1.0, abs=1e-6)

    def test_different_decay_rates(self, mock_loader_large):
        """Test 7: a larger decay rate yields a steeper first/last weight ratio."""
        sampler_low = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=1.0)
        sampler_mid = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=2.0)
        sampler_high = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=5.0)

        # The first/last weight ratio should grow with the decay rate.
        ratio_low = sampler_low._weights[0] / sampler_low._weights[-1]
        ratio_mid = sampler_mid._weights[0] / sampler_mid._weights[-1]
        ratio_high = sampler_high._weights[0] / sampler_high._weights[-1]

        assert ratio_low < ratio_mid < ratio_high

    def test_weighted_sampling_preserves_no_repeat(self, mock_loader_large):
        """Test 8: weighting does not break the no-repeat guarantee."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=3, decay_rate=2.0)

        all_sampled = []
        for _ in range(10):
            names = [s['name'] for s in sampler.sample()]

            # No duplicates within a single draw.
            assert len(names) == len(set(names))

            all_sampled.extend(names)

        # 10 draws x 3 = 30 names, which must cover every pair exactly once.
        assert len(all_sampled) == 30
        assert len(set(all_sampled)) == 30

    def test_weighted_sampling_preserves_reset(self, mock_loader_large):
        """Test 9: weighting does not break the automatic reset on exhaustion."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=5, decay_rate=2.0)

        # 6 draws x 5 = 30, which exhausts the pool exactly.
        for _ in range(6):
            assert len(sampler.sample()) == 5

        # The 7th draw resets the pool and continues.
        assert len(sampler.sample()) == 5

        # Only the 7th draw's indices are tracked after the reset.
        assert len(sampler._used_indices) == 5

    def test_weighted_sampling_bias(self, mock_loader_large):
        """Test 10: weighted sampling favours the earlier pairs."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=2, decay_rate=2.0)

        # How often the first 5 and the last 5 pairs get picked.
        top_5_count = 0
        bottom_5_count = 0

        # NOTE(review): this is stochastic and unseeded. It relies on the margin
        # below being wide enough; seed the sampler's RNG if it ever flakes.
        for _ in range(100):
            # Reset so that every iteration draws from the full pool.
            sampler.reset()
            for sample in sampler.sample():
                # Names end in "opt_NN"; NN is the pair's position in the pool.
                idx = int(sample['name'].rsplit('_', 1)[-1])

                if idx < 5:
                    top_5_count += 1
                elif idx >= 25:
                    bottom_5_count += 1

        # The first 5 pairs should be drawn clearly more often than the last 5:
        # at least 1.5x (theoretically higher, with slack for randomness).
        assert top_5_count > bottom_5_count
        assert top_5_count > bottom_5_count * 1.5

    def test_manual_reset(self, mock_loader_large):
        """Test 11: reset() clears the used indices and sampling resumes."""
        sampler = HandwriteSampler(loader=mock_loader_large, sample_num=5, decay_rate=2.0)

        sampler.sample()
        assert len(sampler._used_indices) == 5

        sampler.reset()
        assert len(sampler._used_indices) == 0

        # After a manual reset a full draw is possible again.
        assert len(sampler.sample()) == 5


class TestMultiRoundScenario:
    """Multi-round sampling scenario built on the shared ``mock_loader`` fixture."""

    def test_multi_island_multi_round(self, mock_loader):
        """Test 12: integration check with several islands over several rounds."""
        num_islands, num_rounds, sample_num = 3, 2, 2

        # Every island owns a private sampler; all share the same mock loader.
        island_samplers = [
            HandwriteSampler(loader=mock_loader, sample_num=sample_num)
            for _ in range(num_islands)
        ]

        # Each round visits every island once, in island order.
        for _ in range(num_rounds):
            for sampler in island_samplers:
                suggestions = sampler.sample()
                assert 1 <= len(suggestions) <= sample_num

        # 2 rounds x 2 per draw = 4 distinct indices consumed per island.
        expected_used = num_rounds * sample_num
        for sampler in island_samplers:
            assert len(sampler._used_indices) == expected_used