02f94c82创建于 2025年5月14日历史提交
package tokenizer
import encoding.json.stream.JsonReader
internal import std.collection.{ArrayList, HashMap, HashSet}
import std.fs.{File, Path, OpenMode, exists}
import std.io.{readToEnd, ByteBuffer}
import std.regex.*
import std.math.*

/**
 * Tokenizer back-ends that `Tokenizer.createTokenizer` can be asked to build.
 * NOTE(review): `TIKTOIKEN` looks like a misspelling of "TIKTOKEN"; it is kept as-is because
 * renaming a public enum constructor would break existing callers.
 */
public enum TokenizerType {
  SENTENCEPIECE | TIKTOIKEN | BERT | HUGGINGFACE
}

/**
 * An ordered pair of strings that can be used as a `HashMap`/`HashSet` key.
 * Equality compares both components in order.
 */
public class HashPairString <: Hashable & Equatable<HashPairString> {
  public var first: String
  public var second: String

  /**
   * Builds the pair from a tuple: `p[0]` becomes `first`, `p[1]` becomes `second`.
   */
  public init(p: (String, String)) {
    this.first = p[0]
    this.second = p[1]
  }

  /**
   * Order-sensitive hash of both components.
   * A plain `hash1 ^ hash2` is symmetric, so (a, b) and (b, a) collide, and it maps every
   * (x, x) pair to 0. Multiply-and-add avoids both problems.
   * Wrapping arithmetic is required because string hash codes span the whole Int64 range
   * and the default overflow policy would throw.
   */
  @OverflowWrapping
  public func hashCode(): Int64 {
    let hash1: Int64 = this.first.hashCode()
    let hash2: Int64 = this.second.hashCode()
    return hash1 * 31 + hash2
  }

  /** True when both components are equal, in the same order. */
  public operator func == (other: HashPairString): Bool {
    return this.first == other.first && this.second == other.second
  }

  /** Negation of `==`. */
  public operator func != (other: HashPairString): Bool {
    return !(this == other)
  }
}

/**
 * An ordered pair of UInt32 values (e.g. token ids) usable as a `HashMap`/`HashSet` key.
 * Equality compares both components in order.
 */
public class HashPairUInt32 <: Hashable & Equatable<HashPairUInt32> & ToString {
  public var first: UInt32
  public var second: UInt32

  /**
   * Builds the pair from a tuple: `p[0]` becomes `first`, `p[1]` becomes `second`.
   */
  public init(p: (UInt32, UInt32)) {
    this.first = p[0]
    this.second = p[1]
  }

  /** Renders the pair as "(first, second)". */
  public func toString(): String {
    return "(${this.first}, ${this.second})"
  }

  /**
   * Order-sensitive hash: `first * (2^31 - 1) + second`.
   *
   * The previous `hash1 ^ hash2` had three problems:
   * - it collided for swapped pairs;
   * - it mapped every (x, x) pair to 0;
   * - for small ids it piled many pairs into one bucket ((1, 2), (2, 1) and (3, 0) all gave 3).
   *
   * With the multiplier 2^31 - 1, the largest possible result, for (2^32 - 1, 2^32 - 1), is
   * 2^63 - 2^31 < Int64.Max. No input can therefore raise an overflow exception. The hash is
   * collision-free whenever `second` < 2^31 - 1.
   */
  public func hashCode(): Int64 {
    return Int64(this.first) * 2147483647 + Int64(this.second)
  }

  /** True when both components are equal, in the same order. */
  public operator func == (other: HashPairUInt32): Bool {
    return this.first == other.first && this.second == other.second
  }

  /** Negation of `==`. */
  public operator func != (other: HashPairUInt32): Bool {
    return !(this == other)
  }
}

/**
 * Base class for tokenizers.
 *
 * Holds the special/stop/prefix token-id lists shared by subclasses and provides the
 * `createTokenizer` factory. `encode`/`decode` here are placeholders that return empty
 * results; concrete tokenizers override them.
 */
public open class Tokenizer {
  // Token ids reported as special by `is_special`.
  protected var special_tokens_: ArrayList<UInt32> = ArrayList<UInt32>()
  // Token ids reported as stop tokens by `is_stop`.
  protected var stop_tokens_: ArrayList<UInt32> = ArrayList<UInt32>()
  // NOTE(review): not read in this class; presumably ids prepended on encode — confirm in subclasses.
  protected var prefix_tokens_: ArrayList<UInt32> = ArrayList<UInt32>()

  // NOTE(review): not referenced in this class; presumably a vocab file-format marker — confirm.
  public static const MAGIC_NUMBER:UInt32 = 430

  /**
   * Creates a tokenizer of `tokenizer_type` and loads its vocabulary from `file_path`.
   * Only `TokenizerType.HUGGINGFACE` is currently supported.
   *
   * Throws `Exception` for any other tokenizer type, and propagates whatever `load_vocab` throws.
   */
  public static func createTokenizer(file_path: String, tokenizer_type: TokenizerType): Tokenizer {
    match (tokenizer_type) {
      case TokenizerType.HUGGINGFACE =>
        let tokenizer = HuggingfaceTokenizer()
        tokenizer.load_vocab(file_path: file_path)
        println("tokenizer load ok")
        return tokenizer
      case _ => throw Exception("unsupported tokenizer type")
    }
  }

  /** True when `token` is in the stop-token list. Linear scan of `stop_tokens_`. */
  public func is_stop(token:UInt32): Bool {
    return stop_tokens_.contains(token)
  }

  /** True when `token` is in the special-token list. Linear scan of `special_tokens_`. */
  public func is_special(token:UInt32): Bool {
    return special_tokens_.contains(token)
  }

  /** Placeholder encoder: always returns an empty array. Subclasses override it. */
  public open func encode(_: String, add_special_tokens!: Bool): Array<UInt32> {
    if (add_special_tokens) {
      // todo
    }
    return Array<UInt32>()
  }

  /** Placeholder decoder: always returns "". Subclasses override it. */
  public open func decode(_: Array<UInt32>, skip_special_tokens!: Bool): String {
    if (skip_special_tokens) {
      // todo
    }
    return ""
  }

  /** Hook for subclasses to parse special-token data; a no-op here. */
  protected open func load_special(_: String) {
  }

  /**
   * Validates the vocabulary source.
   * Throws `Exception` when `file_path` does not exist and `buffer` is empty; otherwise does nothing.
   */
  public open func load_vocab(file_path!: String, buffer!:Array<UInt8>) {
    if (!exists(file_path) && buffer.size == 0) {
      throw Exception("${file_path} not exists and buffer is empty")
    }
  }
}

/**
 * Greedy longest-match tokenizer over a string-to-id vocabulary.
 */
public class Tiktoken <: Tokenizer {
  // Maps token text to its id; used by `encode`.
  protected var vocab_: HashMap<String,UInt32> = HashMap<String,UInt32>()
  // NOTE(review): not populated or read in this class; presumably id -> token text for decode — confirm.
  protected var decoder_: ArrayList<String> = ArrayList<String>()

  /** Placeholder decoder: always returns "". */
  public override func decode(_: Array<UInt32>, skip_special_tokens!: Bool = false): String {
    if (skip_special_tokens) {
    }
    return ""
  }

  /**
   * Validates the vocabulary source.
   * Throws `Exception` when `file_path` does not exist and `buffer` is empty; otherwise does
   * nothing (no vocabulary is read here).
   */
  public override func load_vocab(file_path!: String = "", buffer!: Array<UInt8> = Array<UInt8>()){
    if (!exists(file_path) && buffer.size == 0) {
      throw Exception("${file_path} not exists and buffer is empty")
    }
  }

  /**
   * Encodes `str` by repeatedly taking the longest vocabulary entry that starts at the current
   * position. `add_special_tokens` is currently ignored.
   *
   * When no entry matches at a position, an error is printed to stderr and that one UTF-8
   * character is skipped. Without the skip, the loop would never advance and would run forever.
   */
  public override func encode(str: String, add_special_tokens!: Bool = false): Array<UInt32> {
    if (add_special_tokens) {
      // todo
    }
    let token_ids = ArrayList<UInt32>()
    var i: Int64 = 0
    while (i < str.size) {
      var match_len: Int64 = 0
      var len = str.size - i
      // Lengths are tried longest-first, so the first hit is the longest match: stop there.
      while (len > 0) {
        // `str.size` and slicing are byte-based; only slice at UTF-8 character boundaries.
        if (Tiktoken.is_char_boundary(str, i + len)) {
          let candidate = str[i..i + len]
          if (this.vocab_.contains(candidate)) {
            token_ids.add(this.vocab_[candidate])
            match_len = len
            break
          }
        }
        len--
      }
      if (match_len > 0) {
        i += match_len
      } else {
        eprintln("Error: No encoding found for the sequence starting at position ${i}")
        // Skip the whole character at `i` (lead byte plus any continuation bytes).
        i++
        while (!Tiktoken.is_char_boundary(str, i)) {
          i++
        }
      }
    }
    return token_ids.toArray()
  }

  // True when byte offset `pos` is at the start, at the end, or on a non-continuation
  // (not 0b10xxxxxx) byte of `s`, i.e. on a UTF-8 character boundary.
  private static func is_char_boundary(s: String, pos: Int64): Bool {
    if (pos <= 0 || pos >= s.size) {
      return true
    }
    return (s[pos] & 0xC0) != 0x80
  }
}