package magic.tokenizer
internal import magic.core.tokenizer.Tokenizer
internal import magic.utils.{readFile, StringExt, exists}
internal import magic.core.tokenizer.*
internal import std.collection.*
internal import std.regex.*
internal import std.fs.*
internal import encoding.json.*
internal import std.collection.*
internal import std.fs.*
internal import std.io.*
internal import std.regex.*
internal import serialization.serialization.*
internal import std.math.pow
// Largest value representable by UInt32 (2^32 - 1). Within this file it serves as the
// "no rank available" sentinel when scoring candidate BPE merges.
public let UINT32_MAX = UInt32(4294967295)
/**
 * Serializable<T> variant for types that are only read from JSON.
 *
 * Implementers supply the static `deserialize(DataModel)`; `serialize` is not supported
 * and `fromJson` is a convenience wrapper around `deserialize`.
 */
public interface JsonDeserializable<T> <: Serializable<T> {
    /**
     * Serialization is not supported by this interface.
     * Always throws UnsupportedException.
     */
    func serialize(): DataModel {
        throw UnsupportedException()
    }

    /**
     * Parses `str` as JSON, converts it to a DataModel and hands it to `deserialize`.
     */
    static func fromJson(str: String) {
        deserialize(DataModel.fromJson(JsonValue.fromStr(str)))
    }
}
/**
 * One entry of the "added_tokens" list: an id, its text and the matching flags
 * read from the keys id / content / single_word / lstrip / rstrip / normalized / special.
 */
struct Token <: JsonDeserializable<Token> {
    let _id: UInt32
    let _content: String
    let _singleWord: Bool
    let _lstrip: Bool
    let _rstrip: Bool
    let _normalized: Bool
    let _special: Bool

    // Instances are only created through `deserialize`.
    private init(
        tokenId: UInt32,
        text: String,
        isSingleWord: Bool,
        stripLeft: Bool,
        stripRight: Bool,
        isNormalized: Bool,
        isSpecial: Bool
    ) {
        this._id = tokenId
        this._content = text
        this._singleWord = isSingleWord
        this._lstrip = stripLeft
        this._rstrip = stripRight
        this._normalized = isNormalized
        this._special = isSpecial
    }

    /**
     * Builds a Token from a DataModelStruct.
     * Throws Exception when `dm` is not a DataModelStruct.
     */
    public static func deserialize(dm: DataModel): Token {
        let fields = match (dm) {
            case structData: DataModelStruct => structData
            case _ => throw Exception("this data is not DataModelStruct")
        }
        // Arguments are evaluated left to right, i.e. in the same key order as before.
        return Token(
            UInt32.deserialize(fields.get("id")),
            String.deserialize(fields.get("content")),
            Bool.deserialize(fields.get("single_word")),
            Bool.deserialize(fields.get("lstrip")),
            Bool.deserialize(fields.get("rstrip")),
            Bool.deserialize(fields.get("normalized")),
            Bool.deserialize(fields.get("special"))
        )
    }
}
/**
 * The "normalizer" section; only its "type" string is retained.
 */
struct Normalizer <: JsonDeserializable<Normalizer> {
    let _type: String

    // Instances are only created through `deserialize`.
    private init(kind: String) {
        this._type = kind
    }

    /**
     * Builds a Normalizer from a DataModelStruct.
     * Throws Exception when `dm` is not a DataModelStruct.
     */
    public static func deserialize(dm: DataModel): Normalizer {
        let fields = match (dm) {
            case structData: DataModelStruct => structData
            case _ => throw Exception("this data is not DataModelStruct")
        }
        return Normalizer(String.deserialize(fields.get("type")))
    }
}
/**
 * A split pattern; holds the string stored under the "Regex" key.
 */
struct Pattern <: JsonDeserializable<Pattern> {
    let _regex: String

    // Instances are only created through `deserialize`.
    private init(expression: String) {
        this._regex = expression
    }

    /**
     * Builds a Pattern from a DataModelStruct.
     * Throws Exception when `dm` is not a DataModelStruct.
     */
    public static func deserialize(dm: DataModel): Pattern {
        let fields = match (dm) {
            case structData: DataModelStruct => structData
            case _ => throw Exception("this data is not DataModelStruct")
        }
        return Pattern(String.deserialize(fields.get("Regex")))
    }
}
/**
 * A processing step that is either of type "ByteLevel" or of type "Split".
 *
 * Exactly one group of optional fields is populated, depending on `_type`:
 *  - "ByteLevel": `_addPrefixSpace`, `_trimOffsets`, `_useRegex`
 *  - "Split":     `_pattern`, `_behavior`, `_invert`
 * The other group is None.
 */
struct Process <: JsonDeserializable<Process> {
    // Discriminator: "ByteLevel" or "Split".
    let _type: String

    // Populated for ByteLevel (used for decoder / post_processor).
    let _addPrefixSpace: Option<Bool>
    let _trimOffsets: Option<Bool>
    let _useRegex: Option<Bool>

    // Populated for Split.
    let _pattern: Option<Pattern>
    let _behavior: Option<String>
    let _invert: Option<Bool>

    /**
     * ByteLevel constructor.
     * Throws Exception when `pType` is not "ByteLevel".
     */
    private init(
        pType: String,
        addPrefixSpace: Bool,
        trimOffsets: Bool,
        useRegex: Bool
    ) {
        this._type = pType
        // Check the parameter, consistent with the Split constructor.
        if (pType != "ByteLevel") {
            throw Exception("this construction function only support ByteLevel type")
        }
        this._addPrefixSpace = Some(addPrefixSpace)
        this._trimOffsets = Some(trimOffsets)
        this._useRegex = Some(useRegex)
        this._pattern = None
        this._behavior = None
        this._invert = None
    }

    /**
     * Split constructor.
     * Throws Exception when `pType` is not "Split".
     */
    private init(
        pType: String,
        pattern: Pattern,
        behavior: String,
        invert: Bool
    ) {
        this._type = pType
        if (pType != "Split") {
            throw Exception("this construction function only support Split type")
        }
        this._addPrefixSpace = None
        this._trimOffsets = None
        this._useRegex = None
        this._pattern = Some(pattern)
        this._behavior = Some(behavior)
        this._invert = Some(invert)
    }

    /**
     * Builds a Process from a DataModelStruct, dispatching on its "type" key.
     * Throws Exception when `dm` is not a DataModelStruct or the type is neither
     * "ByteLevel" nor "Split".
     */
    public static func deserialize(dm: DataModel): Process {
        let dms = match (dm) {
            case data: DataModelStruct => data
            case _ => throw Exception("this data is not DataModelStruct")
        }
        let _type = String.deserialize(dms.get("type"))
        return match (_type) {
            case "ByteLevel" =>
                let addPrefixSpace = Bool.deserialize(dms.get("add_prefix_space"))
                let trimOffsets = Bool.deserialize(dms.get("trim_offsets"))
                let useRegex = Bool.deserialize(dms.get("use_regex"))
                Process(_type, addPrefixSpace, trimOffsets, useRegex)
            case "Split" =>
                let pattern = Pattern.deserialize(dms.get("pattern"))
                let behavior = String.deserialize(dms.get("behavior"))
                let invert = Bool.deserialize(dms.get("invert"))
                Process(_type, pattern, behavior, invert)
            // The previous message carried a stray '}' after the interpolated type.
            case _ => throw Exception("unknown process type ${_type}")
        }
    }
}
/**
 * The "pre_tokenizer" section: a type string plus the list of Process steps
 * stored under "pretokenizers".
 */
struct PreTokenizer <: JsonDeserializable<PreTokenizer> {
    let _type: String
    let _pretokenizers: ArrayList<Process>

    // Used by `deserialize`.
    private init(kind: String, steps: ArrayList<Process>) {
        this._type = kind
        this._pretokenizers = steps
    }

    // Empty instance: blank type and no steps.
    private init() {
        this._type = ""
        this._pretokenizers = ArrayList<Process>()
    }

    /**
     * Builds a PreTokenizer from a DataModelStruct.
     * Throws Exception when `dm` is not a DataModelStruct.
     */
    public static func deserialize(dm: DataModel): PreTokenizer {
        let fields = match (dm) {
            case structData: DataModelStruct => structData
            case _ => throw Exception("this data is not DataModelStruct")
        }
        let kind = String.deserialize(fields.get("type"))
        let steps = ArrayList<Process>.deserialize(fields.get("pretokenizers"))
        return PreTokenizer(kind, steps)
    }
}
/**
 * The "model" section of the tokenizer description: vocabulary, merge list and
 * the options that accompany them.
 */
struct Model <: JsonDeserializable<Model> {
    let _type: String
    // Optional dropout value; the constructor rejects anything outside [0, 1].
    // NOTE(review): the previous comment ("perform no merges, so the result will just be
    // characters") presumably describes the effect of the maximum value - confirm with the consumer.
    let _dropout: Option<Float32>
    // The unknown token to be used when we encounter an unknown char
    let _unkToken: Option<String>
    // An optional prefix to use on any subword that exists only behind another one
    let _continuingSubwordPrefix: String
    // An optional suffix to characterize an end-of-word subword
    let _endOfWordSuffix: String
    // Whether multiple consecutive unk tokens get fused
    let _fuseUnk: Bool
    // Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"`
    // for each byte in the unk token
    let _byteFallback: Bool
    let _vocab: HashMap<String, UInt32>
    let _merges: ArrayList<String>

    /**
     * Throws Exception when `dropout` is present and outside [0, 1].
     */
    private init(
        _type: String,
        dropout: Option<Float32>,
        vocab: HashMap<String, UInt32>,
        merges: ArrayList<String>,
        unkToken!: Option<String> = None,
        continuingSubwordPrefix!: String = "",
        endOfWordSuffix!: String = "",
        fuseUnk!: Bool = false,
        byteFallback!: Bool = false
    ) {
        this._type = _type
        if (let Some(dropoutValue) <- dropout) {
            if (dropoutValue < 0.0 || dropoutValue > 1.0) {
                throw Exception("dropout can only between 0~1")
            }
        }
        this._dropout = dropout
        this._unkToken = unkToken
        this._continuingSubwordPrefix = continuingSubwordPrefix
        this._endOfWordSuffix = endOfWordSuffix
        this._fuseUnk = fuseUnk
        this._byteFallback = byteFallback
        this._vocab = vocab
        this._merges = merges
    }

    /**
     * Builds a Model from a DataModelStruct.
     * A missing/null continuing_subword_prefix or end_of_word_suffix becomes "".
     * Throws Exception when `dm` is not a DataModelStruct or dropout is out of range.
     */
    public static func deserialize(dm: DataModel): Model {
        let dms = match (dm) {
            case data: DataModelStruct => data
            case _ => throw Exception("this data is not DataModelStruct")
        }
        let _type = String.deserialize(dms.get("type"))
        let dropout = Option<Float32>.deserialize(dms.get("dropout"))
        let vocab = HashMap<String, UInt32>.deserialize(dms.get("vocab"))
        let merges = ArrayList<String>.deserialize(dms.get("merges"))
        let unkToken = Option<String>.deserialize(dms.get("unk_token"))
        let continuingSubwordPrefix = Option<String>.deserialize(dms.get("continuing_subword_prefix"))
        let endOfWordSuffix = Option<String>.deserialize(dms.get("end_of_word_suffix"))
        let fuseUnk = Bool.deserialize(dms.get("fuse_unk"))
        let byteFallback = Bool.deserialize(dms.get("byte_fallback"))
        // fuseUnk was previously parsed but never forwarded, leaving _fuseUnk stuck at false.
        return Model(
            _type,
            dropout,
            vocab,
            merges,
            unkToken: unkToken,
            continuingSubwordPrefix: continuingSubwordPrefix.getOrDefault({=> ""}),
            endOfWordSuffix: endOfWordSuffix.getOrDefault({=> ""}),
            fuseUnk: fuseUnk,
            byteFallback: byteFallback
        )
    }
}
/**
 * Top-level tokenizer description: version, truncation/padding settings, added tokens,
 * normalizer, pre-tokenizer, post-processor, decoder and model.
 */
public struct TokenizerJson <: JsonDeserializable<TokenizerJson> {
    let _version: String
    let _truncation: Option<String>
    let _padding: Option<String>
    let _addedTokens: ArrayList<Token>
    let _normalizer: Normalizer
    let _pre_tokenizer: PreTokenizer
    let _post_processor: Process
    let _decoder: Process
    let _model: Model

    // Plain field-by-field constructor.
    init(
        version: String,
        truncation: Option<String>,
        padding: Option<String>,
        addedTokens: ArrayList<Token>,
        normalizer: Normalizer,
        pre_tokenizer: PreTokenizer,
        post_processor: Process,
        decoder: Process,
        model: Model
    ) {
        this._version = version
        this._truncation = truncation
        this._padding = padding
        this._addedTokens = addedTokens
        this._normalizer = normalizer
        this._pre_tokenizer = pre_tokenizer
        this._post_processor = post_processor
        this._decoder = decoder
        this._model = model
    }

    /**
     * Builds a TokenizerJson from a DataModelStruct.
     *
     * post_processor, decoder and model are first read as Option values and only
     * unwrapped once every section has been deserialized; an absent one then makes
     * `getOrThrow()` throw.
     * Throws Exception when `dm` is not a DataModelStruct.
     */
    public static func deserialize(dm: DataModel): TokenizerJson {
        let root = match (dm) {
            case structData: DataModelStruct => structData
            case _ => throw Exception("this data is not DataModelStruct")
        }
        let versionText = String.deserialize(root.get("version"))
        let truncationCfg = Option<String>.deserialize(root.get("truncation"))
        let paddingCfg = Option<String>.deserialize(root.get("padding"))
        let extraTokens = ArrayList<Token>.deserialize(root.get("added_tokens"))
        let normalizerCfg = Normalizer.deserialize(root.get("normalizer"))
        let preTokenizerCfg = PreTokenizer.deserialize(root.get("pre_tokenizer"))
        let maybePostProcessor = Option<Process>.deserialize(root.get("post_processor"))
        let maybeDecoder = Option<Process>.deserialize(root.get("decoder"))
        let maybeModel = Option<Model>.deserialize(root.get("model"))
        return TokenizerJson(
            versionText,
            truncationCfg,
            paddingCfg,
            extraTokens,
            normalizerCfg,
            preTokenizerCfg,
            maybePostProcessor.getOrThrow(),
            maybeDecoder.getOrThrow(),
            maybeModel.getOrThrow()
        )
    }
}
/**
 * Tokenizer configuration: whether to add BOS/EOS, the bos/eos/pad/unk token strings,
 * the maximum model length and the tokenizer class name.
 */
public struct BPETokenizerConfig <: JsonDeserializable<BPETokenizerConfig> {
    let _addBosToken: Bool
    let _addEosToken: Bool
    let _bosToken: Option<String>
    let _eosToken: Option<String>
    let _padToken: Option<String>
    let _unkToken: Option<String>
    let _modelMaxLength: UInt32
    let _tokenizerClass: String

    // Instances are only created through `deserialize`.
    private init(
        addBosToken: Bool,
        addEosToken: Bool,
        bosToken: Option<String>,
        eosToken: Option<String>,
        padToken: Option<String>,
        unkToken: Option<String>,
        modelMaxLength: UInt32,
        tokenizerClass: String
    ) {
        this._addBosToken = addBosToken
        this._addEosToken = addEosToken
        this._bosToken = bosToken
        this._eosToken = eosToken
        this._padToken = padToken
        this._unkToken = unkToken
        this._modelMaxLength = modelMaxLength
        this._tokenizerClass = tokenizerClass
    }

    /**
     * Reads a token entry that is either a plain string or a struct whose "content"
     * member is a string. Every other shape (including a struct whose "content" is
     * not a string) yields None.
     */
    private static func tokenContent(dm: DataModel): Option<String> {
        let content: Option<String> = match (dm) {
            case data: DataModelString => Option<String>.deserialize(data)
            case data: DataModelStruct => match (data.get("content")) {
                case innerData: DataModelString => Option<String>.deserialize(innerData)
                case _ => None
            }
            case _ => None
        }
        return content
    }

    /**
     * Builds a BPETokenizerConfig from a DataModelStruct.
     * add_bos_token / add_eos_token default to false when absent.
     * Throws Exception when `dm` is not a DataModelStruct.
     */
    public static func deserialize(dm: DataModel): BPETokenizerConfig {
        let dms = match (dm) {
            case data: DataModelStruct => data
            case _ => throw Exception("this data is not DataModelStruct")
        }
        let addBosToken = Option<Bool>.deserialize(dms.get("add_bos_token"))
        let addEosToken = Option<Bool>.deserialize(dms.get("add_eos_token"))
        // NOTE(review): the result of reading key "_" is discarded; this looks like dead
        // code. Kept to leave behavior untouched - confirm and remove.
        let _ = Option<Bool>.deserialize(dms.get("_"))
        // The four token entries share one shape, so they share one reader.
        let bosToken = tokenContent(dms.get("bos_token"))
        let eosToken = tokenContent(dms.get("eos_token"))
        let padToken = tokenContent(dms.get("pad_token"))
        let unkToken = tokenContent(dms.get("unk_token"))
        let modelMaxLength = UInt32.deserialize(dms.get("model_max_length"))
        let tokenizerClass = String.deserialize(dms.get("tokenizer_class"))
        return BPETokenizerConfig(
            addBosToken.getOrDefault({=> false}),
            addEosToken.getOrDefault({=> false}),
            bosToken,
            eosToken,
            padToken,
            unkToken,
            modelMaxLength,
            tokenizerClass
        )
    }
}
/**
 * An immutable, ordered pair of two values of the same hashable, equatable type.
 */
public class Pair<T> <: Hashable & Equatable<Pair<T>> where T <: Hashable & Equatable<T> {
    let _left: T
    let _right: T

    public init(left: T, right: T) {
        this._left = left
        this._right = right
    }

    // First component.
    public prop left: T {
        get() {
            _left
        }
    }

    // Second component.
    public prop right: T {
        get() {
            _right
        }
    }

    /**
     * XOR of both component hashes. When the two hashes are identical the XOR would
     * collapse to zero, so that single hash is returned instead.
     */
    public func hashCode(): Int64 {
        let leftHash: Int64 = _left.hashCode()
        let rightHash: Int64 = _right.hashCode()
        if (leftHash == rightHash) {
            return leftHash
        }
        return leftHash ^ rightHash
    }

    // Two pairs are equal when both components are equal, position by position.
    public operator func ==(other: Pair<T>): Bool {
        return _left == other._left && _right == other._right
    }

    @When[cjc_version < "0.56.4"]
    public operator func !=(other: Pair<T>): Bool {
        return !(this == other)
    }
}
// Merge ranks keyed by a token string. The merge loop in this file looks ranks up with the
// concatenation of two adjacent pieces and prefers the smallest value.
type BPERanks = HashMap<String, UInt32>
// Keyed by a pair of token ids; the value is a (rank, new_id) tuple according to the
// comment on the `_merges` field below.
public type MergeMap = HashMap<Pair<UInt32>, (UInt32, UInt32)>
/**
 * Shared implementation of a byte-level BPE tokenizer.
 *
 * This class only reads its tables (`_vocab`, `_vocabR`, `_bpeRanks`, `_byte2rune`,
 * `_rune2byte`, `_specialTokens`, ...); nothing in it fills them, so they are expected
 * to be populated elsewhere before `encode` / `decode` are called. After changing
 * `_specialTokens`, call `buildSepcialRegex()` so that `encode` recognises them.
 */
public abstract class AbstractBPETokenizer <: Tokenizer {
    // The _vocabulary assigns a number to each token.
    var _vocab: HashMap<String, UInt32> = HashMap()
    // Reversed _vocabulary, to rebuild sentences.
    var _vocabR: HashMap<UInt32, String> = HashMap()
    // Contains the mapping between Pairs and their (rank, new_id).
    var _merges: MergeMap = MergeMap()
    // Rank of a merge, keyed by the concatenation of the two merged pieces (see `bpe`).
    var _bpeRanks: BPERanks = BPERanks()
    // Id emitted for pieces that are missing from _vocab, if any.
    var _unkTokenId: Option<UInt32> = None
    /// Whether or not to direct output words if they are part of the _vocab.
    var _ignoreMerges: Bool = false;
    // Cache: rune-level word -> pieces produced by `bpe` for it.
    var _cache: HashMap<String, Array<String>> = HashMap()
    // The key is the uniformed special token, and the value is the original token.
    // See - func unexpectedNormalize and unexpectedRestore
    var _specialTokens: HashMap<String, String> = HashMap()
    var _specialIds: HashSet<UInt32> = HashSet()
    // Byte <-> rune tables of the byte-level alphabet.
    var _byte2rune: HashMap<UInt8, Rune> = HashMap()
    var _rune2byte: HashMap<Rune, UInt8> = HashMap()
    // Alternation of all special tokens; rebuilt by buildSepcialRegex().
    var _specialRegex: Regex = Regex("")
    // Use a fixed preSplitRegex as default,
    // for the same reason as func unexpectedNormalize and unexpectedRestore.
    var _preSplitRegex: Regex = Regex(
        "('s|'t|'re|'ve|'m|'ll|'d| ?[[:punct:]]?[[:alpha:]]+| ?[[:punct:]]?[[:digit:]]+| ?[[:punct:]]?[^\\s\\w]+|\\s+)")
    var _addBosToken: Bool = false
    var _addEosToken: Bool = false
    var _bosToken: Option<String> = None
    var _eosToken: Option<String> = None

    /**
     * Returns the normalized content of `token`, trimmed on the left and/or right
     * according to its lstrip / rstrip flags.
     */
    private func uniformSpecial(token: Token): String {
        var content = unexpectedNormalize(token._content)
        if (token._lstrip) {
            content = content.trimAsciiLeft()
        }
        if (token._rstrip) {
            content = content.trimAsciiRight()
        }
        return content
    }

    /**
     * Due to the weak support for Regex in the current version, replace the common word
     * segmentation symbols with ASCII stand-ins: U+2581 becomes "_" and the fullwidth
     * vertical line U+FF5C becomes "|".
     * Sometimes this makes the encoded result a little different when the input contains
     * special UTF-8 characters.
     * This function should be removed in future versions.
     */
    func unexpectedNormalize(token: String): String {
        // The fullwidth bar is written as an escape: if the literal is ever flattened to an
        // ASCII '|' the call degenerates into replace("|", "|"), which does nothing.
        return token.replace("▁", "_").replace("\u{FF5C}", "|")
    }

    /**
     * Due to the weak support for Regex in the current version, replace the common word
     * segmentation symbols.
     * This function is the opposite of "unexpectedNormalize": it maps a normalized special
     * token back to its original spelling and returns any other string unchanged.
     * This function should be removed in the future.
     */
    func unexpectedRestore(token: String): String {
        return this._specialTokens.get(token).getOrDefault({=> token})
    }

    /**
     * Rebuilds `_specialRegex` as one capture group that matches any key of `_specialTokens`.
     *
     * With no special tokens the regex is reset to the empty pattern instead of compiling
     * the invalid pattern "("; `encode` skips special-token matching in that state.
     * (The spelling "Sepcial" is kept because the name is part of the class interface.)
     */
    func buildSepcialRegex() {
        if (this._specialTokens.isEmpty()) {
            this._specialRegex = Regex("")
            return
        }
        let alternatives = ArrayList<String>()
        for (spToken in this._specialTokens.keys()) {
            alternatives.append(escapeRegexLiteral(spToken))
        }
        this._specialRegex = Regex("(" + String.join(alternatives.toArray(), delimiter: "|") + ")")
    }

    /**
     * Escapes every regex metacharacter in `text` so that it only matches itself.
     * Tokens whose only metacharacter is '|' produce the same pattern as before.
     */
    private func escapeRegexLiteral(text: String): String {
        // The backslash goes first so the backslashes added below are not escaped again.
        var escaped = text.replace("\\", "\\\\")
        for (meta in [".", "^", "\$", "*", "+", "?", "(", ")", "[", "]", "{", "}", "|"]) {
            escaped = escaped.replace(meta, "\\" + meta)
        }
        return escaped
    }

    /**
     * Converts token ids back to text.
     *
     * Each id is looked up in `_vocabR`; every rune of the resulting piece is mapped to a
     * byte through `_rune2byte`, and the collected bytes are decoded as UTF-8.
     * A rune without an entry in `_rune2byte` contributes its own UTF-8 bytes instead of
     * aborting the whole decode.
     * Throws when an id is missing from `_vocabR` (via getOrThrow) or when
     * String.fromUtf8 rejects the collected bytes.
     */
    public func decode(tokens: Array<UInt32>): String {
        let datas = ArrayList<UInt8>()
        for (tokenId in tokens) {
            let piece = this._vocabR.get(tokenId).getOrThrow()
            for (rune in piece.toRuneArray()) {
                match (this._rune2byte.get(rune)) {
                    case Some(byte) => datas.append(byte)
                    // NOTE(review): presumably only reachable for added/special tokens whose
                    // text is not written in the byte-level alphabet - confirm.
                    case None => datas.appendAll(rune.toString().toArray())
                }
            }
        }
        return String.fromUtf8(datas.toArray())
    }

    /**
     * Converts text to token ids.
     *
     * The normalized input is cut at every special-token match: a special token maps
     * straight to its vocabulary id, and the text between two special tokens is
     * pre-split and BPE-encoded on its own. Special tokens are therefore recognised even
     * when they sit next to whitespace; the earlier approach blanked them out with spaces
     * and lost them whenever that run of spaces merged with neighbouring whitespace.
     * BOS / EOS ids are added when enabled and present in `_vocab`.
     * Throws (via getOrThrow) when a matched special token is missing from `_vocab`.
     */
    public func encode(input: String): Array<UInt32> {
        let uniformed = unexpectedNormalize(input)
        let ids = ArrayList<UInt32>()
        if (this._addBosToken) {
            appendBoundaryToken(this._bosToken, ids)
        }
        // Byte offset in `uniformed` up to which the input has been consumed.
        var cursor = 0
        if (!this._specialTokens.isEmpty()) {
            let specialMatches = this._specialRegex.matcher(uniformed).findAll() ?? Array<MatchData>()
            for (specialMatch in specialMatches) {
                let position = specialMatch.matchPosition()
                // Defensive: ignore empty or overlapping matches.
                if (position.end <= position.start || position.start < cursor) {
                    continue
                }
                if (position.start > cursor) {
                    encodeOrdinary(uniformed[cursor..position.start], ids)
                }
                ids.append(this._vocab.get(specialMatch.matchStr()).getOrThrow())
                cursor = position.end
            }
        }
        if (cursor < uniformed.size) {
            encodeOrdinary(uniformed[cursor..uniformed.size], ids)
        }
        if (this._addEosToken) {
            appendBoundaryToken(this._eosToken, ids)
        }
        return ids.toArray()
    }

    // Appends the id of `token` when it is set and present in _vocab; otherwise does nothing.
    private func appendBoundaryToken(token: Option<String>, ids: ArrayList<UInt32>): Unit {
        if (let Some(text) <- token) {
            if (let Some(id) <- this._vocab.get(text)) {
                ids.append(id)
            }
        }
    }

    /**
     * Encodes text that contains no special tokens: pre-split it with `_preSplitRegex`,
     * turn each word into byte-level runes, merge them with `bpe` (cached per word) and
     * append the resulting ids to `ids`.
     */
    private func encodeOrdinary(text: String, ids: ArrayList<UInt32>): Unit {
        let words = this._preSplitRegex.matcher(text).findAll() ?? Array<MatchData>()
        for (word in words) {
            let matchedStr = word.matchStr()
            var runeLevelStr: String
            var runeList = ArrayList<String>()
            if (this._vocab.contains(matchedStr)) {
                // The whole word is a vocabulary entry: use it as a single piece.
                runeList.append(matchedStr)
                runeLevelStr = matchedStr
            } else {
                for (b in matchedStr.toArray()) {
                    let rune = this._byte2rune[b]
                    runeList.append(rune.toString())
                }
                runeLevelStr = String.join(runeList.toArray())
            }
            let pieces: Array<String> = match (this._cache.get(runeLevelStr)) {
                case Some(cached) => cached
                case None =>
                    let merged = this.bpe(runeList.toArray())
                    // update cache
                    this._cache.put(runeLevelStr, merged)
                    merged
            }
            // Cached and freshly merged pieces take the same path, so the unknown-token
            // fallback no longer depends on whether the word was seen before.
            for (piece in pieces) {
                appendTokenId(piece, ids)
            }
        }
    }

    /**
     * Appends the id of `piece`, falling back to `_unkTokenId` when the piece is not in
     * `_vocab`. When neither is available the piece is skipped and the failure is
     * reported through printStackTrace (best-effort, as before).
     */
    private func appendTokenId(piece: String, ids: ArrayList<UInt32>): Unit {
        try {
            let id = this._vocab.get(piece).getOrDefault({=> this._unkTokenId.getOrThrow()})
            ids.append(id)
        } catch (e: Exception) {
            e.printStackTrace()
        }
    }

    /**
     * Repeatedly merges adjacent pieces, best (lowest) rank first, until no adjacent
     * pair has a rank in `_bpeRanks`.
     */
    private func bpe(tokens: Array<String>): Array<String> {
        if (tokens.size == 0) {
            return []
        }
        var finalTokens = ArrayList<String>(tokens)
        while (finalTokens.size > 1) {
            let pairs: ArrayList<Pair<String>> = this.buildAdjacentConcatenations(finalTokens)
            var minScore: UInt32 = UINT32_MAX
            var toMerge: Int64 = -1
            for ((i, pair) in enumerate(pairs)) {
                // Pairs without a rank score UINT32_MAX and are therefore never selected.
                let score = this._bpeRanks.get(pair.left + pair.right).getOrDefault({=> UINT32_MAX})
                if (score < minScore) {
                    minScore = score
                    toMerge = i
                }
            }
            if (toMerge == -1) {
                // Nothing left to merge.
                break
            }
            let bestTokens: ArrayList<String> = ArrayList<String>()
            let left = pairs[toMerge].left
            let right = pairs[toMerge].right
            // Find all occurrences of the selected pair, left to right, and merge each one.
            var i = 0
            while (i < finalTokens.size) {
                let j = find<String>(finalTokens, i, pairs[toMerge])
                if (j >= 0) {
                    bestTokens.appendAll(finalTokens[i..j])
                    bestTokens.append(left + right)
                    i = j + 2
                } else {
                    bestTokens.appendAll(finalTokens[i..])
                    break
                }
            }
            finalTokens = bestTokens
        }
        return finalTokens.toArray()
    }

    /**
     * Returns the first position p >= index where arr[p] and arr[p + 1] equal
     * target.left and target.right, or -1 when there is none.
     */
    private func find<T>(arr: ArrayList<T>, index: Int64, target: Pair<T>): Int64 where T <: Hashable & Equatable<T> {
        // The upper bound is exclusive, so arr[i + 1] is always in range.
        for (i in index..arr.size - 1) {
            if (arr[i] == target.left && arr[i + 1] == target.right) {
                return i
            }
        }
        return -1
    }

    /**
     * ['a','b','c'] -> [('a','b'),('b','c')]
     */
    private func buildAdjacentConcatenations(tokens: ArrayList<String>): ArrayList<Pair<String>> {
        var pairs: ArrayList<Pair<String>> = ArrayList<Pair<String>>()
        if (tokens.size <= 1) {
            return pairs
        }
        for (i in 0..tokens.size - 1) {
            pairs.append(Pair<String>(tokens[i].toString(), tokens[i + 1].toString()))
        }
        return pairs
    }
}