from typing import List
import sentencepiece
import torch
# Maps the name of an instance attribute set by StepChatTokenizer.__init__
# to the SentencePiece piece whose id that attribute will hold.
TOKEN_FLAG_MAPPINGS = {
    # Pieces that check_tokens() requires to be in the vocabulary AND to be
    # control/unknown (special) tokens.
    "_bot_id": "<|BOT|>",
    "_eot_id": "<|EOT|>",
    "_call_start_id": "<|CALL_START|>",
    "_call_end_id": "<|CALL_END|>",
    "_think_start_id": "<|THINK_START|>",
    "_think_end_id": "<|THINK_END|>",
    # Pieces that only have to be present in the vocabulary.
    "_mask_start_id": "<|MASK_1e69f|>",
    "_mask_end_id": "<|UNMASK_1e69f|>",
}
class StepChatTokenizer:
    """SentencePiece-backed tokenizer for Step chat models.

    On construction it loads a SentencePiece model, builds piece <-> id
    lookup tables, records which pieces are control/unknown tokens, verifies
    that the chat flag tokens from ``TOKEN_FLAG_MAPPINGS`` exist, and stores
    their ids as instance attributes (``_bot_id``, ``_eot_id``, ...).
    """

    # Flag tokens that must be in the vocabulary but are not required to be
    # control/unknown (special) tokens.
    _NON_SPECIAL_FLAG_KEYS = frozenset({"_mask_start_id", "_mask_end_id"})

    def __init__(
        self, pretrained_model_name_or_path,
        model_max_length=320,
        **kwargs
    ):
        """Load the model file and build the vocabulary tables.

        Args:
            pretrained_model_name_or_path: Passed to
                ``sentencepiece.SentencePieceProcessor`` as ``model_file``.
            model_max_length: Stored on the instance as ``model_max_length``;
                this class does not otherwise use it.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If a required flag token is missing from the
                vocabulary or is not a special token (see ``check_tokens``).

        Side effects:
            Writes ``str(self.__dict__)`` to ``./token.txt`` (best effort).
        """
        self._tokenizer = sentencepiece.SentencePieceProcessor(
            model_file=pretrained_model_name_or_path
        )
        self.model_max_length = model_max_length

        self._vocab = {}
        self._inv_vocab = {}
        self._special_tokens = {}
        self._inv_special_tokens = {}
        self._t5_tokens = []

        for idx in range(self._tokenizer.get_piece_size()):
            text = self._tokenizer.id_to_piece(idx)
            self._inv_vocab[idx] = text
            self._vocab[text] = idx
            # Control and unknown pieces are what this class calls "special".
            if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
                self._special_tokens[text] = idx
                self._inv_special_tokens[idx] = text

        self._unk_id = self._tokenizer.unk_id()
        self._bos_id = self._tokenizer.bos_id()
        self._eos_id = self._tokenizer.eos_id()

        # Validate before resolving ids so a missing flag token fails loudly
        # rather than being resolved to some fallback id by piece_to_id.
        self.check_tokens()
        for key, value in TOKEN_FLAG_MAPPINGS.items():
            setattr(self, key, self._tokenizer.piece_to_id(value))
        self._underline_id = self._tokenizer.piece_to_id("\u2581")

        self._dump_debug_state()

    def _dump_debug_state(self, path="./token.txt"):
        """Write ``str(self.__dict__)`` to *path*, warning instead of failing.

        The file is written as UTF-8 because the vocabulary contains
        non-ASCII pieces (e.g. U+2581) that a locale-dependent default
        encoding may be unable to encode. An ``OSError`` (read-only or
        missing directory, ...) is reported as a ``RuntimeWarning`` so that
        a failed dump does not prevent the tokenizer from being built.
        """
        import warnings

        try:
            with open(path, "w", encoding="utf-8") as f:
                f.write(str(self.__dict__))
        except OSError as err:
            warnings.warn(
                f"could not write tokenizer state dump to {path!r}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )

    @property
    def get_tokenizer(self):
        """The underlying ``SentencePieceProcessor`` (a property, not a method)."""
        return self._tokenizer

    @property
    def vocab(self):
        """Mapping from piece text to id."""
        return self._vocab

    @property
    def inv_vocab(self):
        """Mapping from id to piece text."""
        return self._inv_vocab

    @property
    def vocab_size(self):
        """Vocabulary size as reported by the SentencePiece processor."""
        return self._tokenizer.vocab_size()

    def tokenize(self, text: str) -> List[int]:
        """Encode *text* into a list of token ids."""
        return self._tokenizer.encode_as_ids(text)

    def detokenize(self, token_ids: List[int]) -> str:
        """Decode a list of token ids back into text."""
        return self._tokenizer.decode_ids(token_ids)

    def check_tokens(self):
        """Verify that the flag tokens exist and have the expected kind.

        Every piece in ``TOKEN_FLAG_MAPPINGS`` must be in the vocabulary;
        every piece except the mask start/end pieces must additionally be a
        control/unknown token.

        Raises:
            ValueError: On the first token that violates either requirement.
        """
        # Presence is checked for all tokens first, then special-ness, so the
        # "not found" error takes precedence over "not a special token".
        for token in TOKEN_FLAG_MAPPINGS.values():
            if token not in self._vocab:
                raise ValueError(f"Token '{token}' not found in tokenizer")

        for key, token in TOKEN_FLAG_MAPPINGS.items():
            if key in self._NON_SPECIAL_FLAG_KEYS:
                continue
            if token not in self._special_tokens:
                raise ValueError(f"Token '{token}' is not a special token")
class Tokens:
    """Pair of token ids and attention mask produced by the tokenizer.

    Fields can be read as attributes (``tokens.input_ids``) or by string
    index (``tokens["input_ids"]``).
    """

    def __init__(self, input_ids, attention_mask) -> None:
        self.input_ids, self.attention_mask = input_ids, attention_mask

    def to(self, device):
        """Replace both fields with the result of their ``.to(device)``; return ``self``."""
        for field in ("input_ids", "attention_mask"):
            moved = getattr(self, field).to(device)
            setattr(self, field, moved)
        return self

    def __getitem__(self, item):
        """Return the attribute named *item*.

        Raises:
            AttributeError: If no such attribute exists.
        """
        absent = object()
        found = getattr(self, item, absent)
        if found is absent:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
        return found
class WrappedStepChatTokenizer(StepChatTokenizer):
    """Callable tokenizer that returns fixed-length, padded batches."""

    # Framing ids used when building rows. These are fixed constants, not
    # read from the SentencePiece model. PAD shares the EOS id, so padding
    # positions are distinguishable only through the attention mask.
    BOS = 1
    EOS = 2
    PAD = 2

    def __call__(self, text, max_length=320, **kwargs):
        """Tokenize *text* into BOS/EOS-framed rows padded to *max_length*.

        Args:
            text: A single string, or a list/tuple of strings (one row each).
                An empty list/tuple yields a single row holding only BOS/EOS.
            max_length: Row length. Content tokens are truncated to
                ``max_length - 2`` to leave room for BOS and EOS. For
                ``max_length < 2`` every row is just ``[BOS, EOS]`` (length 2).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tokens: ``input_ids`` and ``attention_mask`` as ``torch.long``
            tensors with one row per input; the mask is 1 over
            BOS + content + EOS and 0 over padding.
        """
        if not isinstance(text, (list, tuple)):
            text = [text]

        # An empty batch is treated as one empty piece of text.
        token_lists = [self.tokenize(part) for part in text] if text else [[]]

        # Clamp the content budget: a negative slice bound would drop tokens
        # from the END of each row by a per-row amount, producing rows of
        # differing lengths that cannot be stacked into a tensor.
        budget = max(max_length - 2, 0)

        out_tokens = []
        attn_mask = []
        for tokens in token_lists:
            row = [self.BOS, *tokens[:budget], self.EOS]
            valid_size = len(row)
            pad_size = max(max_length - valid_size, 0)
            out_tokens.append(row + [self.PAD] * pad_size)
            attn_mask.append([1] * valid_size + [0] * pad_size)

        out_tokens = torch.tensor(out_tokens, dtype=torch.long)
        attn_mask = torch.tensor(attn_mask, dtype=torch.long)
        return Tokens(out_tokens, attn_mask)