"""文档分块处理器"""

import hashlib
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple


@dataclass
class TextChunk:
    """A piece of a document together with its origin and metadata.

    Pass ``chunk_id=""`` (or any other falsy value) to have an ID derived
    automatically from ``source`` and ``content``.
    """

    content: str  # the chunk text
    source: str  # identifier of the originating document
    # auto-generated form: "<sha256(source)[:8]>_<sha256(content)[:16]>"
    chunk_id: str
    metadata: Dict[str, Any]  # free-form attributes supplied by the producer

    def __post_init__(self):
        # Keep a caller-supplied ID; otherwise derive one from the hashes.
        self.chunk_id = self.chunk_id or self._generate_id()

    def _generate_id(self) -> str:
        """Build a deterministic ID from hashes of ``source`` and ``content``."""

        def short_sha256(text: str, width: int) -> str:
            # First `width` hex digits of the SHA-256 of the encoded text.
            return hashlib.sha256(text.encode()).hexdigest()[:width]

        return "_".join((short_sha256(self.source, 8), short_sha256(self.content, 16)))


class DocumentProcessor:
    """Load documents and split them into ``TextChunk`` objects.

    Two chunking strategies are supported, selected by ``preserve_structure``:

    * structure-based (default): the text is cut at Markdown ATX headings
      (``#`` to ``######``); lines inside fenced code blocks are never taken
      as headings. A section of at most ``chunk_size`` characters becomes one
      chunk; a longer one is split on sentence boundaries, carrying up to
      ``chunk_overlap`` characters of trailing sentences into the next
      sub-chunk.
    * size-based: fixed windows of ``chunk_size`` characters that advance by
      ``chunk_size - chunk_overlap`` characters.
    """

    HEADING_PATTERN = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
    CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
    # Markers that open/close a fenced code block. "#" lines inside a fence are
    # code (e.g. shell or Python comments), not headings.
    CODE_FENCE_MARKERS = ("```", "~~~")
    # Sentence boundaries. ASCII terminators must be followed by whitespace so
    # that e.g. "3.14" is not split. CJK terminators (U+3002, U+FF01, U+FF1F)
    # are normally not followed by a space, so split right after them -- but
    # not between consecutive terminators.
    SENTENCE_BOUNDARY_PATTERN = re.compile(
        r"(?<=[.!?])\s+|(?<=[\u3002\uff01\uff1f])(?![\u3002\uff01\uff1f])\s*"
    )

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        preserve_structure: bool = True,
    ):
        """Configure the processor.

        Args:
            chunk_size: Maximum chunk length in characters; must be positive.
            chunk_overlap: Characters of context repeated between consecutive
                chunks; values <= 0 disable overlap.
            preserve_structure: Chunk by headings/sentences when True, by
                fixed-size character windows when False.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            # A non-positive size bounds nothing: the size-based window step
            # would be <= 0 and every sentence would count as oversized.
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.preserve_structure = preserve_structure

    def load_document(self, file_path: str) -> str:
        """Read ``file_path`` as UTF-8 text.

        A leading byte-order mark is stripped ("utf-8-sig" decodes plain UTF-8
        unchanged); otherwise it would stop a heading on line 1 from matching.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        try:
            return Path(file_path).read_text(encoding="utf-8-sig")
        except FileNotFoundError:
            raise FileNotFoundError(f"File not found: {file_path}") from None

    def process(self, text: str, source: str = "") -> List["TextChunk"]:
        """Split ``text`` into chunks using the configured strategy.

        Args:
            text: Document text.
            source: Identifier of the document; stored on every chunk and
                hashed into its ID.

        Raises:
            ValueError: In size-based mode, if ``chunk_overlap`` >= ``chunk_size``.
        """
        if self.preserve_structure:
            return self._chunk_by_structure(text, source)
        return self._chunk_by_size(text, source)

    def _chunk_by_structure(self, text: str, source: str) -> List["TextChunk"]:
        """Chunk ``text`` section by section (see ``_extract_sections``).

        Sections of at most ``chunk_size`` characters become a single chunk;
        longer ones are split by ``_split_large_text``. Every chunk's metadata
        carries ``start_line``/``end_line`` (0-based, inclusive) of the section
        it came from; for sub-chunks that is the enclosing section's range.
        """
        chunks: List["TextChunk"] = []
        for section in self._extract_sections(text.split("\n")):
            section_text = "\n".join(section["lines"]).strip()
            if not section_text:
                continue  # e.g. only blank lines before the first heading
            line_range = {"start_line": section["start"], "end_line": section["end"]}
            if len(section_text) <= self.chunk_size:
                chunks.append(
                    TextChunk(
                        content=section_text,
                        source=source,
                        chunk_id="",
                        metadata={
                            "heading": section["heading"],
                            "level": section["level"],
                            **line_range,
                        },
                    )
                )
                continue
            for sub_chunk in self._split_large_text(
                section_text, source, section["heading"], section["level"]
            ):
                sub_chunk.metadata.update(line_range)
                chunks.append(sub_chunk)
        return chunks

    def _extract_sections(self, lines: List[str]) -> List[Dict]:
        """Group lines into sections that each start at a Markdown heading.

        Lines before the first heading form a section with heading "" and
        level 0. Each section dict has keys ``heading``, ``level``, ``start``
        and ``end`` (0-based, inclusive line indices) and ``lines`` (for
        heading sections, the heading line is included). Lines inside fenced
        code blocks are never treated as headings.
        """
        sections: List[Dict] = []
        current: Dict[str, Any] = {"heading": "", "level": 0, "start": 0, "lines": []}
        open_fence = None  # marker of the enclosing code fence, if any

        for i, line in enumerate(lines):
            stripped = line.lstrip()
            if open_fence is not None:
                # Inside a fence: only a matching marker closes it.
                if stripped.startswith(open_fence):
                    open_fence = None
                current["lines"].append(line)
                continue
            if stripped.startswith(self.CODE_FENCE_MARKERS):
                open_fence = stripped[:3]
                current["lines"].append(line)
                continue

            heading_match = self.HEADING_PATTERN.match(line)
            if heading_match is None:
                current["lines"].append(line)
                continue

            if current["lines"]:
                current["end"] = i - 1
                sections.append(current)
            current = {
                "heading": heading_match.group(2).strip(),
                "level": len(heading_match.group(1)),
                "start": i,
                "lines": [line],
            }

        if current["lines"]:
            current["end"] = len(lines) - 1
            sections.append(current)
        return sections

    def _split_large_text(
        self, text: str, source: str, heading: str, level: int
    ) -> List["TextChunk"]:
        """Split a section longer than ``chunk_size`` into sentence-aligned sub-chunks.

        Sentences are re-joined with single spaces, so whitespace (including
        line breaks) after a sentence terminator is not preserved. Each
        sub-chunk is tagged with the section ``heading``/``level`` and
        ``is_subchunk=True``.
        """
        return [
            TextChunk(
                content=piece,
                source=source,
                chunk_id="",
                metadata={"heading": heading, "level": level, "is_subchunk": True},
            )
            for piece in self._group_sentences(self._split_into_sentences(text))
        ]

    def _group_sentences(self, sentences: List[str]) -> List[str]:
        """Pack sentences into space-joined strings of at most ``chunk_size`` chars.

        After each emitted piece, the trailing sentences chosen by
        ``_get_overlap_sentences`` seed the next piece. The oldest of them are
        dropped if keeping them would push the next piece past ``chunk_size``.
        """
        pieces: List[str] = []
        current: List[str] = []
        current_length = 0  # always equals len(" ".join(current))

        for sentence in self._split_oversized_sentences(sentences):
            if current and current_length + 1 + len(sentence) > self.chunk_size:
                pieces.append(" ".join(current))
                current = self._get_overlap_sentences(current)
                current_length = len(" ".join(current))
                while current and current_length + 1 + len(sentence) > self.chunk_size:
                    dropped = current.pop(0)
                    # Its separator goes too, unless it was the only sentence left.
                    current_length -= len(dropped) + (1 if current else 0)
            if current:
                current_length += 1  # the space inserted by " ".join
            current.append(sentence)
            current_length += len(sentence)

        if current:
            pieces.append(" ".join(current))
        return pieces

    def _split_oversized_sentences(self, sentences: List[str]) -> List[str]:
        """Hard-split any sentence longer than ``chunk_size`` into stripped slices."""
        bounded: List[str] = []
        size = self.chunk_size
        for sentence in sentences:
            if len(sentence) <= size:
                bounded.append(sentence)
                continue
            for start in range(0, len(sentence), size):
                piece = sentence[start:start + size].strip()
                if piece:
                    bounded.append(piece)
        return bounded

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split ``text`` at sentence terminators; return non-empty stripped parts."""
        parts = (part.strip() for part in self.SENTENCE_BOUNDARY_PATTERN.split(text))
        return [part for part in parts if part]

    def _get_overlap_sentences(self, sentences: List[str]) -> List[str]:
        """Return the longest suffix of ``sentences`` whose space-joined length
        is at most ``chunk_overlap`` (empty when overlap is disabled).

        Always returns a new list, so callers may mutate it.
        """
        if self.chunk_overlap <= 0:
            return []
        suffix: List[str] = []
        total = 0  # len(" ".join(suffix))
        for sentence in reversed(sentences):
            added = len(sentence) + (1 if suffix else 0)
            if total + added > self.chunk_overlap:
                break
            suffix.append(sentence)
            total += added
        suffix.reverse()
        return suffix

    def _chunk_by_size(self, text: str, source: str) -> List["TextChunk"]:
        """Chunk ``text`` into fixed-size character windows (see ``_size_windows``).

        Metadata ``start``/``end`` are the window's character offsets in
        ``text`` (before the chunk text is stripped).
        """
        return [
            TextChunk(
                content=chunk_text,
                source=source,
                chunk_id="",
                metadata={"start": start, "end": end, "method": "size_based"},
            )
            for start, end, chunk_text in self._size_windows(text)
        ]

    def _size_windows(self, text: str) -> List[Tuple[int, int, str]]:
        """Return ``(start, end, stripped_text)`` for each non-blank window.

        Windows are ``chunk_size`` characters long and start every
        ``chunk_size - max(chunk_overlap, 0)`` characters; windowing stops once
        a window reaches the end of ``text``.

        Raises:
            ValueError: If ``chunk_overlap`` >= ``chunk_size`` (the windows
                would never advance).
        """
        step = self.chunk_size - max(self.chunk_overlap, 0)
        if step <= 0:
            raise ValueError(
                f"chunk_overlap ({self.chunk_overlap}) must be smaller than "
                f"chunk_size ({self.chunk_size})"
            )
        windows: List[Tuple[int, int, str]] = []
        length = len(text)
        for start in range(0, length, step):
            end = min(start + self.chunk_size, length)
            chunk_text = text[start:end].strip()
            if chunk_text:
                windows.append((start, end, chunk_text))
            if end == length:
                break  # any later window would lie entirely inside this one
        return windows