from typing import Callable, List, Tuple
from openjiuwen.core.common.logging import logger
from openjiuwen.core.retrieval.indexing.processor.splitter.base import Splitter
from pysbd import Segmenter
class SentenceSplitter(Splitter):
    """Split text into chunks made of whole sentences.

    Sentences are found with pysbd and greedily packed into chunks of at most
    ``chunk_size`` tokens. A sentence that alone exceeds ``chunk_size`` is
    emitted as its own chunk (it is not split further). When ``chunk_overlap``
    is positive, trailing sentences of a chunk are repeated at the start of the
    next chunk, up to ``chunk_overlap`` tokens.
    """

    def __init__(
        self,
        tokenizer: Callable,
        chunk_size: int,
        chunk_overlap: int,
        lan: str = "zh",
    ):
        """
        Initialize sentence splitter

        Args:
            tokenizer: Tokenizer forwarded to the Splitter base class; expected
                to have encode and decode methods
            chunk_size: Maximum chunk size (number of tokens)
            chunk_overlap: Chunk overlap size (number of tokens)
            lan: Language code, defaults to "zh" (Chinese). Stored as
                ``default_lan``; note that ``__call__`` detects the language of
                each document itself and does not read this value.
        """
        super().__init__(
            tokenizer=tokenizer,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        self.default_lan = lan
        # pysbd Segmenter; rebuilt for the detected language on every __call__.
        self.seg = None
        # Not written anywhere in this class; kept so existing attribute access
        # keeps working.
        self._span_recovery_failures = {}

    @staticmethod
    def _detect_chinese(text: str) -> str:
        """
        Detect if text is primarily Chinese or English based on character distribution.

        Only code points in U+4E00..U+9FFF (CJK Unified Ideographs) count as
        Chinese; the ratio is taken over all characters, including whitespace
        and punctuation.

        Args:
            text: Text to analyze

        Returns:
            "zh" if more than 10% of characters are Chinese, otherwise "en"
            (also "en" for empty text).
        """
        if not text:
            return "en"
        chinese_count = sum(1 for char in text if 0x4E00 <= ord(char) <= 0x9FFF)
        return "zh" if chinese_count / len(text) > 0.1 else "en"

    def __call__(self, doc: str) -> List[Tuple[str, int, int]]:
        """
        Split document into sentence-level chunks

        Args:
            doc: Document text to be split

        Returns:
            List of chunks, each element is (text, start char position, end
            char position). ``text`` is the concatenation of the chunk's
            sentence strings and the positions index into ``doc``. Empty or
            whitespace-only input yields an empty list.
        """
        if not doc or not doc.strip():
            return []
        detected_lan = self._detect_chinese(doc)
        self.seg = Segmenter(language=detected_lan, clean=False)

        chunks: List[Tuple[str, int, int]] = []
        cur_sents: List[Tuple[str, int, int, int]] = []
        cur_tokens = 0  # running token total of cur_sents
        for sent in self._sentences_with_spans(doc):
            sent_text, sent_start, sent_end, sent_len = sent
            if not sent_text.strip():
                continue

            if sent_len > self.chunk_size:
                # Oversized sentence: emit the pending sentences, then this
                # sentence alone. Any overlap is discarded on purpose: carrying
                # sentences from before the oversized one into the next chunk
                # would make that chunk's text skip the oversized sentence
                # while its (start, end) span still covered it.
                chunks, _ = self._flush(chunks, cur_sents)
                chunks.append((sent_text, sent_start, sent_end))
                cur_sents, cur_tokens = [], 0
                continue

            if cur_tokens + sent_len <= self.chunk_size:
                cur_sents.append(sent)
                cur_tokens += sent_len
                continue

            # Current chunk is full: emit it, then seed the next chunk with
            # the overlap sentences _flush returns. Drop the oldest overlap
            # sentences until the new sentence fits within chunk_size; this
            # also guarantees the next chunk is never a repeat of this one.
            chunks, overlap = self._flush(chunks, cur_sents)
            overlap_tokens = sum(s[3] for s in overlap)
            while overlap and overlap_tokens + sent_len > self.chunk_size:
                overlap_tokens -= overlap.pop(0)[3]
            cur_sents = overlap + [sent]
            cur_tokens = overlap_tokens + sent_len

        chunks, _ = self._flush(chunks, cur_sents)
        logger.info(
            f"Computed the following sentence-level chunks: {len(chunks)} chunks"
        )
        return chunks

    def _sentences_with_spans(self, text: str) -> List[Tuple[str, int, int, int]]:
        """
        Segment ``text`` with ``self.seg`` and locate each sentence in it.

        Sentences are searched for in order, each search starting where the
        previous match ended. Repeated sentences therefore map to successive
        occurrences, and a sentence is never matched inside an earlier one.
        Sentences that cannot be located are logged and dropped.

        Args:
            text: The text being segmented.

        Returns:
            List of (sentence, start char, end char, token count) tuples in
            document order; whitespace-only sentences are skipped.
        """
        spans: List[Tuple[str, int, int, int]] = []
        cursor = 0
        for sent in self.seg.segment(text):
            if not sent.strip():
                continue
            start = text.find(sent, cursor)
            if start == -1:
                logger.warning(f"Span recovery failed for: {repr(sent[:30])}...")
                continue
            end = start + len(sent)
            # self.tokenizer_enc is presumably set up by the Splitter base
            # class from `tokenizer` -- not visible here.
            spans.append((sent, start, end, len(self.tokenizer_enc(sent))))
            cursor = end
        return spans

    def _flush(
        self,
        chunks: List[Tuple[str, int, int]],
        cur_sents: List[Tuple[str, int, int, int]],
    ) -> Tuple[List[Tuple[str, int, int]], List[Tuple[str, int, int, int]]]:
        """
        Append the pending sentences to ``chunks`` as a single chunk.

        Args:
            chunks: Output list; it is mutated in place and also returned.
            cur_sents: Pending (text, start, end, token count) sentences.

        Returns:
            ``(chunks, overlap)``. ``overlap`` is a new list holding the longest
            suffix of ``cur_sents`` whose token total is at most
            ``chunk_overlap``. It is empty when nothing is pending, when overlap
            is disabled, or when only one sentence was pending.
        """
        if not cur_sents:
            return chunks, []
        chunk_text = "".join(s[0] for s in cur_sents)
        chunks.append((chunk_text, cur_sents[0][1], cur_sents[-1][2]))

        if self.chunk_overlap <= 0 or len(cur_sents) <= 1:
            return chunks, []
        overlap_tokens = 0
        keep = 0
        for sent in reversed(cur_sents):
            if overlap_tokens + sent[3] > self.chunk_overlap:
                break
            overlap_tokens += sent[3]
            keep += 1
        return chunks, cur_sents[len(cur_sents) - keep:]