from typing import List, Optional
from transformers import AutoTokenizer
class LuminaMGPT2Tokenizer:
def __init__(self, **config):
self.tokenizer = AutoTokenizer.from_pretrained(**config)
self.bos_id: int = self.tokenizer.bos_token_id
if self.bos_id is None:
self.bos_id = self.tokenizer.eos_token_id
self.eos_id: int = self.tokenizer.eos_token_id
self._probe_tokenizer_style()
def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
t = self.tokenizer.encode(s, truncation=False, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def encode_wo_prefix_space(self, s: str):
if self.need_space_before_segment:
return self.encode(s, bos=False, eos=False)
else:
l_prefix = ["@", "\n", "\\", "=", ">", "`"]
for prefix in l_prefix:
prefix_tokens = self.encode(prefix, bos=False, eos=False)
cat_tokens = self.encode(prefix + s, bos=False, eos=False)
if cat_tokens[: len(prefix_tokens)] == prefix_tokens:
return cat_tokens[len(prefix_tokens):]
raise NotImplementedError(
f"All prefixes are merged into {s} during tokenization,"
f"This is weird behavior, please open an issue to report this problem",
)
def _probe_tokenizer_style(self):
"""
Given a sentence, e.g. "Hi my darling", some tokenizers (e.g. LLaMA's) will pose the following behavior:
>>> # leading characters will be treated as if there were an " " in the beginning
>>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode("my darling")
>>> # leading space " " is redundant and should not be added
>>> tokenizer.encode("Hi my darling") != tokenizer.encode("Hi") + tokenizer.encode(" my darling")
However, some others (e.g. InternLM's) will behave differently:
>>> # leading space " " has to be explicitly added
>>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode(" my darling")
Knowing which style the tokenizer takes is necessary when tokenzing a segment cut from the complete
text, so that the result is the same as the corresponding part in the tokenized original text.
"""
sentence1 = self.encode("Hi my darling", bos=False, eos=False)
sentence2 = self.encode("my darling", bos=False, eos=False)
if sentence1[-len(sentence2):] == sentence2:
self.need_space_before_segment = False
else:
sentence3 = self.encode(" my darling", bos=False, eos=False)
if sentence1[-len(sentence3):] != sentence3:
raise AssertionError
self.need_space_before_segment = True