// 02f94c82 - created on 2025-05-14 (see version history)
package tokenizer
import encoding.json.stream.JsonReader
internal import std.collection.{ArrayList, HashMap, HashSet}
import std.fs.{File, Path, OpenMode, exists}
import std.io.{readToEnd, ByteBuffer}
import std.regex.*
import std.math.*

// Merge rank keyed by the (left, right) token-string pair; `bpe` merges the pair with the lowest rank first.
public type BPERanks = HashMap<HashPairString, UInt32>
// Maps a (left_id, right_id) token-id pair to (merge rank, id of the merged token).
public type MergeMap = HashMap<HashPairUInt32, (UInt32, UInt32)>

/**
 * Byte-level BPE tokenizer. Input bytes are first mapped to a printable rune alphabet
 * (see `get_byte_char`), then merged with ranked BPE rules. The vocabulary and merges
 * come either from `init` or from a JSON config parsed by `TokenizerJson` in `load_vocab`.
 */
public class HuggingfaceTokenizer <: Tokenizer {
  // Token string -> id. `load_vocab` also inserts the added/special tokens here.
  public var vocab: HashMap<String, UInt32> = HashMap<String, UInt32>()
  // Reverse of `vocab`: id -> token string (used by `decode`).
  public var vocab_r: HashMap<UInt32, String> = HashMap<UInt32, String>()
  // (left_id, right_id) -> (merge rank, merged token id).
  public var merges: MergeMap = MergeMap()
  // Stored from the configuration; the BPE loop in this class does not apply dropout.
  public var dropout: Option<Float32> = None
  public var unk_token: Option<String> = None
  // Id of `unk_token` in `vocab`, if both exist; `encode` falls back to it for unknown pieces.
  public var unk_token_id: Option<UInt32> = None
  public var continuing_subword_prefix: Option<String> = None
  public var end_of_word_suffix: Option<String> = None
  public var fuse_unk: Bool = false
  public var byte_fallback: Bool = false
  public var ignore_merges: Bool = false
  // Memoized BPE output, keyed by the byte-level string of a pre-tokenized piece.
  private let cache: HashMap<String, Array<String>> = HashMap<String, Array<String>>()
  private var bpe_ranks_: BPERanks = BPERanks()
  // Contents (not ids) of the added/special tokens loaded by `load_vocab`.
  private let special_id_set: HashSet<String> = HashSet<String>()
  private var byte2rune: HashMap<UInt8, Rune> = HashMap<UInt8, Rune>()
  private var rune2byte: HashMap<Rune, UInt8> = HashMap<Rune, UInt8>()

  /**
   * Builds a tokenizer from an in-memory vocabulary and an ordered merge list.
   *
   * @param vocab token string -> id; stored by reference, not copied.
   * @param merge_list merge rules in priority order (index 0 has the lowest rank, i.e. merges first).
   *   Both sides of every rule and the merged token must be present in `vocab`, otherwise
   *   construction throws.
   * @param unk_token if set and present in `vocab`, its id becomes `unk_token_id`.
   * @param continuing_subword_prefix if set, that many bytes are stripped from the right-hand
   *   token when forming the merged token string.
   * The remaining options are stored as given.
   */
  public init(
    vocab!: HashMap<String, UInt32> = HashMap<String, UInt32>(),
    merge_list!: ArrayList<(String, String)> = ArrayList<(String, String)>(),
    dropout!: Option<Float32> = None,
    unk_token!: Option<String> = None,
    continuing_subword_prefix!: Option<String> = None,
    end_of_word_suffix!: Option<String> = None,
    fuse_unk!: Bool = false,
    byte_fallback!: Bool = false,
    ignore_merges!: Bool = false
  ) {
    this.vocab = vocab
    for ((token, id) in this.vocab) {
      this.vocab_r.add(id, token)
    }
    (this.byte2rune, this.rune2byte) = this.get_byte_char()
    this.dropout = dropout
    this.unk_token = unk_token
    // Resolve the unknown-token id now so encode() can fall back to it.
    if (let Some(unk) <- unk_token) {
      this.unk_token_id = this.vocab.get(unk)
    }
    this.continuing_subword_prefix = continuing_subword_prefix
    this.end_of_word_suffix = end_of_word_suffix
    this.fuse_unk = fuse_unk
    this.byte_fallback = byte_fallback
    this.ignore_merges = ignore_merges
    this.add_merges(merge_list, continuing_subword_prefix.getOrDefault({=> ""}).size)
  }

  /**
   * Returns the byte <-> rune tables of the byte-level alphabet.
   * Bytes in '!'..'~', U+00A1..U+00AC and U+00AE..U+00FF map to the rune with the same
   * code point; every other byte maps, in ascending byte order, to U+0100, U+0101, ...
   */
  public func get_byte_char(): (HashMap<UInt8, Rune>, HashMap<Rune, UInt8>) {
    let byte_to_rune = HashMap<UInt8, Rune>()
    let rune_to_byte = HashMap<Rune, UInt8>()
    let ranges: Array<(Rune, Rune)> = [(r'!', r'~'), (r'\u{A1}', r'\u{AC}'), (r'\u{AE}', r'\u{FF}')]
    for ((start, end) in ranges) {
      for (x in UInt32(start)..=UInt32(end)) {
        byte_to_rune.add(UInt8(x), Rune(x))
        rune_to_byte.add(Rune(x), UInt8(x))
      }
    }
    // Bytes outside the printable ranges are shifted past 0xFF so every rune is printable.
    var next: UInt32 = 256
    for (b in 0_u8..=255_u8) {
      if (!byte_to_rune.contains(b)) {
        byte_to_rune.add(b, Rune(next))
        rune_to_byte.add(Rune(next), b)
        next++
      }
    }
    return (byte_to_rune, rune_to_byte)
  }

  /**
   * Loads vocabulary, added tokens, merges and model options from a JSON config.
   * `file_path` takes priority when it names an existing file; otherwise `buffer` is parsed.
   * Any state from a previous load (reverse vocab, merges, special tokens, cache) is discarded.
   *
   * Throws if neither source is available, if a merge line does not have exactly two
   * space-separated parts, or if a merge references a token missing from the vocabulary.
   */
  public override func load_vocab(file_path!: String = "", buffer!: Array<UInt8> = Array<UInt8>()) {
    let json_buffer: Array<UInt8> = if (!file_path.isEmpty() && exists(file_path)) {
      this.read_file_bytes(file_path)
    } else if (buffer.size != 0) {
      buffer
    } else {
      throw Exception("${file_path} not exists and buffer is empty")
    }
    let byte_stream = ByteBuffer()
    byte_stream.write(json_buffer)
    let tokenizer_config = TokenizerJson.fromJson(JsonReader(byte_stream))
    let model = tokenizer_config.model

    // Drop everything derived from an earlier load so two vocabularies never mix.
    this.vocab_r.clear()
    this.merges.clear()
    this.bpe_ranks_.clear()
    this.special_id_set.clear()
    this.cache.clear()

    (this.byte2rune, this.rune2byte) = this.get_byte_char()
    this.vocab = model.vocab
    for ((token, id) in this.vocab) {
      this.vocab_r.add(id, token)
    }
    // Added tokens go into both directions explicitly, rather than relying on `vocab`
    // aliasing `model.vocab`.
    for (add_token in tokenizer_config.added_tokens) {
      this.vocab.add(add_token.content, add_token.id)
      this.vocab_r.add(add_token.id, add_token.content)
      this.special_id_set.add(add_token.content)
    }

    let merge_list = ArrayList<(String, String)>()
    for (line in model.merges) {
      let parts = line.split(" ")
      if (parts.size != 2) {
        throw Exception("merge line size must be 2")
      }
      merge_list.add((parts[0], parts[1]))
    }
    this.add_merges(merge_list, model.continuing_subword_prefix.size)

    this.dropout = model.dropout
    // Read the configured unk token BEFORE resolving its id.
    this.unk_token = model.unk_token
    this.unk_token_id = None
    if (let Some(unk) <- this.unk_token) {
      this.unk_token_id = this.vocab.get(unk)
    }
    this.continuing_subword_prefix = model.continuing_subword_prefix
    this.end_of_word_suffix = model.end_of_word_suffix
    this.fuse_unk = model.fuse_unk
    this.byte_fallback = model.byte_fallback
  }

  // Reads the whole file, closing the handle even when the read throws.
  private func read_file_bytes(file_path: String): Array<UInt8> {
    let file = File(file_path, OpenMode.Read)
    var bytes = Array<UInt8>()
    try {
      bytes = file |> readToEnd
    } finally {
      file.close()
    }
    return bytes
  }

  // Registers `merge_list` in order: rule i gets rank i in both `merges` and `bpe_ranks_`.
  // The merged token is `a` + `b` with its first `prefix_len` bytes removed; throws if any
  // of the three tokens is missing from `vocab`.
  private func add_merges(merge_list: ArrayList<(String, String)>, prefix_len: Int64): Unit {
    var rank: UInt32 = 0
    for ((a, b) in merge_list) {
      let a_id: UInt32 = this.vocab.get(a).getOrThrow()
      let b_id: UInt32 = this.vocab.get(b).getOrThrow()
      let new_id: UInt32 = this.vocab.get("${a}${b[prefix_len..]}").getOrThrow()
      this.merges.add(HashPairUInt32((a_id, b_id)), (rank, new_id))
      this.bpe_ranks_.add(HashPairString((a, b)), rank)
      rank++
    }
  }

  /** Returns the id of `token`, or None if it is not in the vocabulary. */
  public func token_to_id(token: String): Option<UInt32> {
    return this.vocab.get(token)
  }

  /** Returns the token string for `id`, or None if the id is unknown. */
  public func id_to_token(id: UInt32): Option<String> {
    return this.vocab_r.get(id)
  }

  /**
   * Renders `messages` with `<|im_start|>role\n...<|im_end|>\n` markers.
   * An optional leading System message is emitted first; the remaining messages must
   * alternate user/assistant (by position) and start and end with a user turn, otherwise
   * this throws. `messages` itself is not modified.
   *
   * @param add_generation_prompt append an open `<|im_start|>assistant\n` turn.
   */
  public func apply_chat_template(
    messages: ArrayList<Message>,
    add_generation_prompt!: Bool = false
  ): String {
    if (messages.size == 0) {
      return ""
    }
    let result = StringBuilder()
    // Skip the system message by offset instead of removing it from the caller's list.
    var first_turn = 0
    match (messages[0].role) {
      case RoleType.System =>
        result.append("<|im_start|>system\n${messages[0].content}<|im_end|>\n")
        first_turn = 1
      case _ => ()
    }
    if ((messages.size - first_turn) % 2 != 1) {
      throw Exception("input times - assistant times must = 1")
    }
    for (i in first_turn..messages.size) {
      if ((i - first_turn) % 2 == 0) {
        result.append("<|im_start|>user\n${messages[i].content}<|im_end|>\n")
      } else {
        result.append("<|im_start|>assistant\n${messages[i].content}<|im_end|>\n")
      }
    }
    if (add_generation_prompt) {
      result.append("<|im_start|>assistant\n")
    }
    return result.toString()
  }

  /**
   * Converts token ids back to text.
   * Ordinary tokens are mapped rune-by-rune through the byte-level alphabet. Added/special
   * tokens are stored as raw text, so their UTF-8 bytes are emitted directly (or they are
   * dropped when `skip_special_tokens` is true).
   * Throws on an unknown id or if the resulting bytes are not valid UTF-8.
   */
  public override func decode(token_ids: Array<UInt32>, skip_special_tokens!: Bool = false): String {
    let data_array = ArrayList<UInt8>()
    for (token_id in token_ids) {
      let token_str = this.id_to_token(token_id).getOrThrow()
      if (this.special_id_set.contains(token_str)) {
        if (skip_special_tokens) {
          continue
        }
        for (b in token_str.toArray()) {
          data_array.add(b)
        }
        continue
      }
      // `rune2byte` is a HashMap; indexing throws for runes outside the byte-level alphabet.
      for (rune in token_str.toRuneArray()) {
        data_array.add(this.rune2byte[rune])
      }
    }
    return String.fromUtf8(data_array.toArray())
  }

  /**
   * Encodes `str` into token ids.
   * The input is first split on added/special tokens (matched literally, longest first),
   * and each special token becomes its own id. Every text segment in between is
   * pre-tokenized and run through byte-level BPE independently.
   * Unknown BPE pieces fall back to `unk_token_id`; throws if that is also absent.
   * NOTE(review): `add_special_tokens` is still unimplemented; no BOS/EOS ids are inserted.
   */
  public override func encode(str: String, add_special_tokens!: Bool = false): Array<UInt32> {
    let ids = ArrayList<UInt32>()
    let text_re: Regex = Regex("('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\\s\\w]+|\\s+)")
    var cursor: Int64 = 0
    if (let Some(special_re) <- this.special_tokens_regex()) {
      let special_array = special_re.matcher(str).findAll() ?? Array<MatchData>()
      for (special in special_array) {
        let position = special.matchPosition()
        if (position.start > cursor) {
          this.encode_plain(str[cursor..position.start], text_re, ids)
        }
        ids.add(this.vocab.get(special.matchString()).getOrThrow())
        cursor = position.end
      }
    }
    if (cursor < str.size) {
      this.encode_plain(str[cursor..str.size], text_re, ids)
    }
    return ids.toArray()
  }

  // Builds a regex alternation over all added/special tokens, or None when there are none.
  // Tokens are fully escaped and ordered longest first, so a token that is a prefix of
  // another cannot shadow it.
  private func special_tokens_regex(): Option<Regex> {
    if (this.special_id_set.size == 0) {
      return None
    }
    let tokens = ArrayList<String>()
    for (t in this.special_id_set) {
      tokens.add(t)
      // Insertion step: keep `tokens` sorted by descending byte length.
      var j = tokens.size - 1
      while (j > 0 && tokens[j - 1].size < tokens[j].size) {
        let tmp = tokens[j - 1]
        tokens[j - 1] = tokens[j]
        tokens[j] = tmp
        j--
      }
    }
    let pattern = StringBuilder()
    pattern.append("(")
    for (i in 0..tokens.size) {
      if (i > 0) {
        pattern.append("|")
      }
      pattern.append(this.escape_regex(tokens[i]))
    }
    pattern.append(")")
    return Some(Regex(pattern.toString()))
  }

  // Backslash-escapes regex metacharacters so `text` is matched literally.
  private func escape_regex(text: String): String {
    let sb = StringBuilder()
    for (r in text.toRuneArray()) {
      match (r) {
        case r'\\' | r'^' | r'$' | r'.' | r'|' | r'?' | r'*' | r'+' | r'(' | r')' | r'[' | r']' | r'{' | r'}' =>
          sb.append(r'\\')
        case _ => ()
      }
      sb.append(r)
    }
    return sb.toString()
  }

  // Pre-tokenizes `segment` (which contains no special tokens), maps each piece's UTF-8
  // bytes into the byte-level alphabet, runs BPE (memoized in `cache`) and appends the ids.
  private func encode_plain(segment: String, text_re: Regex, ids: ArrayList<UInt32>): Unit {
    let pieces = text_re.matcher(segment).findAll() ?? Array<MatchData>()
    for (piece in pieces) {
      let symbols = ArrayList<String>()
      let key = StringBuilder()
      for (b in piece.matchString().toArray()) {
        let symbol = this.byte2rune[b].toString()
        key.append(symbol)
        symbols.add(symbol)
      }
      let cache_key = key.toString()
      let bpe_tokens = match (this.cache.get(cache_key)) {
        case Some(cached) => cached
        case None =>
          let computed = this.bpe(symbols.toArray())
          this.cache.add(cache_key, computed)
          computed
      }
      // Cache hits and misses share the same unk fallback.
      for (t in bpe_tokens) {
        ids.add(this.id_or_unk(t))
      }
    }
  }

  // Vocabulary id of `token`, falling back to `unk_token_id`; throws if neither exists.
  private func id_or_unk(token: String): UInt32 {
    return this.vocab.get(token).getOrDefault({=> this.unk_token_id.getOrThrow()})
  }

  // Adjacent (tokens[i-1], tokens[i]) pairs in order; empty for fewer than two tokens.
  private func get_pairs(tokens: Array<String>): Array<HashPairString> {
    let pairs = ArrayList<HashPairString>()
    for (i in 1..tokens.size) {
      pairs.add(HashPairString((tokens[i - 1], tokens[i])))
    }
    return pairs.toArray()
  }

  // Applies BPE to a sequence of byte-level symbols. It repeatedly merges every occurrence
  // of the adjacent pair with the lowest rank in `bpe_ranks_` (the first such pair on ties).
  // It stops when no adjacent pair is ranked or a single symbol remains.
  private func bpe(token_list: Array<String>): Array<String> {
    var word = token_list
    while (word.size > 1) {
      var best_rank = UInt32.Max
      var best: Option<HashPairString> = None
      for (pair in this.get_pairs(word)) {
        if (let Some(rank) <- this.bpe_ranks_.get(pair)) {
          if (rank < best_rank) {
            best_rank = rank
            best = Some(pair)
          }
        }
      }
      if (best.isNone()) {
        break
      }
      let pair = best.getOrThrow()
      let merged = ArrayList<String>()
      var i = 0
      while (i < word.size) {
        if (i + 1 < word.size && word[i] == pair.first && word[i + 1] == pair.second) {
          merged.add(pair.first + pair.second)
          i += 2
        } else {
          merged.add(word[i])
          i += 1
        }
      }
      word = merged.toArray()
    }
    return word
  }
}