import uuid
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Tuple
import tiktoken
from openjiuwen.core.common.logging import logger
from openjiuwen.core.retrieval.common.document import Document, TextChunk
from transformers import PreTrainedTokenizerBase
from openjiuwen_deepsearch.algorithm.search_index.splitter import SentenceSplitter
DEFAULT_CHUNK_SIZE = 200
DEFAULT_CHAR_CHUNK_SIZE = 200
DEFAULT_CHAR_CHUNK_OVERLAP = 40
class TextSplitter(metaclass=ABCMeta):
@abstractmethod
def split(self, text: TextChunk) -> list[TextChunk]:
pass
class CharSplitter(TextSplitter):
"""Simple text splitter based on character length, no dependency on tokenizer."""
def __init__(
self, chunk_size: int | None = None, chunk_overlap: int | None = None
) -> None:
super().__init__()
size = chunk_size or DEFAULT_CHAR_CHUNK_SIZE
overlap = (
chunk_overlap if chunk_overlap is not None else DEFAULT_CHAR_CHUNK_OVERLAP
)
overlap = max(0, min(overlap, size - 1))
self.chunk_size = max(1, size)
self.chunk_overlap = overlap
def split(self, doc: Document) -> list[TextChunk]:
text = doc.text or ""
doc_id = doc.id_
meta = doc.metadata or {}
res: list[TextChunk] = []
step = (
self.chunk_size - self.chunk_overlap
if self.chunk_size > self.chunk_overlap
else self.chunk_size
)
start = 0
while start < len(text):
end = min(len(text), start + self.chunk_size)
res.append(
TextChunk(
id_=str(uuid.uuid4()),
text=text[start:end],
doc_id=doc_id,
metadata=meta,
)
)
start += step
return res
class IndexSentenceSplitter(TextSplitter):
def __init__(
self,
tokenizer: PreTrainedTokenizerBase | Any = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
splitter_config: dict | None = None,
) -> None:
"""Wrapper with sentence splitting capabilities.
Args:
tokenizer (PreTrainedTokenizerBase): Tokenizer.
chunk_size (int | None, optional): Chunk size to split documents into passages. Defaults to None.
Note: this is based on tokens produced by the tokenizer of embedding model.
If None, set to the maximum sequence length of the embedding model.
chunk_overlap (int | None, optional): Window size for passage overlap. Defaults to None.
If None, set to `chunk_size // 5`.
splitter_config (dict, optional): Other arguments to SentenceSplitter. Defaults to None.
"""
super().__init__()
self._tokenizer = tokenizer
if not isinstance(splitter_config, dict):
splitter_config = {
"paragraph_separator": "\n",
}
tokenizer_fn, max_token_length = self._resolve_tokenizer(self._tokenizer)
chunk_size = self._resolve_chunk_size(chunk_size, max_token_length)
self._splitter = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap or chunk_size // 5,
tokenizer=tokenizer,
)
@staticmethod
def _resolve_tokenizer(
tokenizer: PreTrainedTokenizerBase | Any,
) -> Tuple[Callable[[str], list], int | None]:
"""
Return a tokenizer callable and its max token length (if known).
Falls back to tiktoken when the embedding model has no tokenizer.
"""
if tokenizer is not None:
if hasattr(tokenizer, "tokenize"):
return tokenizer.tokenize, IndexSentenceSplitter._max_length(tokenizer)
if hasattr(tokenizer, "encode"):
return tokenizer.encode, IndexSentenceSplitter._max_length(tokenizer)
if callable(tokenizer):
return tokenizer, IndexSentenceSplitter._max_length(tokenizer)
try:
encoding = tiktoken.get_encoding("cl100k_base")
return encoding.encode, getattr(encoding, "max_token_value", None)
except Exception as exc:
logger.warning("Failed to load tiktoken fallback tokenizer: %s", exc)
return lambda text: text.split(), None
@staticmethod
def _max_length(tokenizer: Any) -> int | None:
"""Try to infer a reasonable maximum token length from a tokenizer."""
for attr in (
"model_max_length",
"max_len_single_sentence",
"max_position_embeddings",
"max_seq_length",
):
if hasattr(tokenizer, attr):
try:
val = int(getattr(tokenizer, attr))
if val and val != float("inf"):
return val
except Exception:
logger.warning("Failed to get max length", exc_info=True)
continue
return None
@staticmethod
def _resolve_chunk_size(
chunk_size: int | None,
max_token_length: int | None,
) -> int:
"""
Decide chunk_size based on caller input and tokenizer limits.
"""
if chunk_size is None and max_token_length:
return max_token_length
if chunk_size is not None and max_token_length:
return min(chunk_size, max_token_length)
return chunk_size or DEFAULT_CHUNK_SIZE
def split(self, doc: TextChunk | Document) -> list[TextChunk]:
if isinstance(doc, Document):
node = doc
else:
node = Document(
id_=doc.doc_id,
text=doc.text,
metadata=doc.metadata,
)
return self._splitter.get_nodes_from_documents([node])