* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package blockkey
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"sync"
"github.com/fxamacker/cbor/v2"
"github.com/go-logr/logr"
"github.com/zeebo/xxh3"
"gitcode.com/openFuyao/cache-indexer/pkg/apis"
)
// Builder derives chained prefix-cache block hashes for a tokenized request.
type Builder interface {
	// BuildFromRequest returns one hash per block of tokenIDs, where the
	// block length is taken from modelCtx.BlockSize. An error is returned
	// when the hashes cannot be derived (e.g. an invalid block size).
	BuildFromRequest(
		modelCtx apis.ModelContext,
		tokenIDs []int32,
	) ([]apis.BlockHash, error)
}
// HashAlgo selects the digest function applied to each CBOR-encoded
// block payload.
type HashAlgo int

const (
	// HashSHA256CBOR hashes the CBOR payload with SHA-256 (32-byte digest).
	// It is the zero value of HashAlgo, so an unset Config.Algo selects it.
	HashSHA256CBOR HashAlgo = iota
	// HashXXHash3CBOR hashes the CBOR payload with 128-bit XXH3 (16-byte
	// digest, written big-endian with the high 64 bits first).
	HashXXHash3CBOR
)
const (
	// cborArrayLenThreeHeader is the CBOR initial byte for a definite-length
	// array of three items (major type 4, length 3). It prefixes the
	// per-block payload [parent, tokens, extra].
	cborArrayLenThreeHeader = 0x83
	// uint64ByteWidth is the byte size of a uint64 / one half of a 128-bit value.
	uint64ByteWidth = 8
	// xxh3HashByteWidth is the byte size of a 128-bit XXH3 digest.
	xxh3HashByteWidth = 16
)
// Config controls how block hashes are derived and rendered.
type Config struct {
	// PythonHashSeed is CBOR-encoded and hashed to produce NONE_HASH, the
	// parent digest of the first block. New rejects an empty value.
	PythonHashSeed string
	// Algo selects the digest function applied to each block payload.
	Algo HashAlgo
	// EmitInt64, when true, renders each hash as the hex encoding of only
	// the last 8 bytes of the digest instead of the full digest.
	// NOTE(review): presumably mirrors an engine mode that reports block
	// hashes as 64-bit integers — confirm against the consumer.
	EmitInt64 bool
}
// builder is the Builder implementation returned by New. It holds a
// sync.Once, so it must not be copied; it is always used via *builder.
type builder struct {
	// log is stored by New. NOTE(review): not read by any method shown
	// here — confirm it is used elsewhere or drop it.
	log logr.Logger
	// cfg is the configuration passed to New.
	cfg Config
	// encMode is the encoder built from cbor.CTAP2EncOptions in New and is
	// used for every CBOR value that gets hashed.
	encMode cbor.EncMode
	// noneOnce guards the lazy, one-time derivation of noneHash/noneErr.
	noneOnce sync.Once
	// noneHash caches the NONE_HASH digest once derived.
	noneHash []byte
	// noneErr caches the error from deriving NONE_HASH, if any.
	noneErr error
}
// New returns a Builder that derives chained block hashes according to cfg.
//
// It returns an error when:
//   - cfg.PythonHashSeed is empty (it seeds NONE_HASH, the parent of the
//     first block, so an empty value would make the chain unstable);
//   - cfg.Algo is not one of the declared HashAlgo constants. Without this
//     check an unknown value would silently be hashed with SHA-256 by the
//     default branch of hash, producing digests the caller did not ask for;
//   - the CBOR encoder cannot be constructed.
func New(log logr.Logger, cfg Config) (Builder, error) {
	if cfg.PythonHashSeed == "" {
		return nil, fmt.Errorf(
			"blockkey: PYTHONHASHSEED is empty; required for stable NONE_HASH derivation",
		)
	}
	switch cfg.Algo {
	case HashSHA256CBOR, HashXXHash3CBOR:
		// Supported.
	default:
		return nil, fmt.Errorf("blockkey: unsupported hash algorithm %d", int(cfg.Algo))
	}
	em, err := cbor.CTAP2EncOptions().EncMode()
	if err != nil {
		return nil, fmt.Errorf("blockkey: cbor encoder: %w", err)
	}
	return &builder{log: log, cfg: cfg, encMode: em}, nil
}
// hash digests payload with the configured algorithm.
//
// HashXXHash3CBOR yields 16 bytes: the 128-bit XXH3 value written
// big-endian, high 64 bits first. Any other Algo value yields the 32-byte
// SHA-256 digest.
func (b *builder) hash(payload []byte) []byte {
	if b.cfg.Algo != HashXXHash3CBOR {
		sum := sha256.Sum256(payload)
		return sum[:]
	}
	h128 := xxh3.Hash128(payload)
	digest := make([]byte, xxh3HashByteWidth)
	binary.BigEndian.PutUint64(digest[:uint64ByteWidth], h128.Hi)
	binary.BigEndian.PutUint64(digest[uint64ByteWidth:], h128.Lo)
	return digest
}
// noneHashBytes returns NONE_HASH, the synthetic parent digest of the first
// block: the configured hash of the CBOR-encoded PYTHONHASHSEED string.
// The derivation runs at most once per builder; its outcome, including an
// encoding error, is cached and returned on every subsequent call.
func (b *builder) noneHashBytes() ([]byte, error) {
	b.noneOnce.Do(func() {
		encodedSeed, encErr := b.encMode.Marshal(b.cfg.PythonHashSeed)
		if encErr == nil {
			b.noneHash = b.hash(encodedSeed)
			return
		}
		b.noneErr = fmt.Errorf("encode PYTHONHASHSEED: %w", encErr)
	})
	return b.noneHash, b.noneErr
}
// BuildFromRequest splits tokenIDs into consecutive blocks of
// modelCtx.BlockSize tokens and returns one chained hash per complete block.
// A trailing partial block is ignored, and a request shorter than one block
// yields (nil, nil).
//
// Each block is hashed together with its parent's raw digest, so the hash of
// block i depends on every token in blocks 0..i. The first block's parent is
// NONE_HASH and its extra keys carry modelCtx.CacheSalt when one is set;
// every later block uses CBOR null as its extra keys.
func (b *builder) BuildFromRequest(
	modelCtx apis.ModelContext,
	tokenIDs []int32,
) ([]apis.BlockHash, error) {
	blockSize := int(modelCtx.BlockSize)
	if blockSize <= 0 {
		return nil, fmt.Errorf("blockkey: block_size must be > 0 (got %d)", blockSize)
	}
	numBlocks := len(tokenIDs) / blockSize
	if numBlocks == 0 {
		return nil, nil
	}

	parent, err := b.noneHashBytes()
	if err != nil {
		return nil, err
	}
	saltExtra, err := b.encodeExtraKeys(modelCtx)
	if err != nil {
		return nil, err
	}
	nullExtra, err := b.encMode.Marshal(nil)
	if err != nil {
		return nil, err
	}

	hashes := make([]apis.BlockHash, 0, numBlocks)
	extra := saltExtra
	for start := 0; len(hashes) < numBlocks; start += blockSize {
		digest, hashErr := b.hashOneBlock(parent, tokenIDs[start:start+blockSize], extra)
		if hashErr != nil {
			return nil, hashErr
		}
		hashes = append(hashes, b.encodeBlockHash(digest))
		// Only the first block carries the salt; later blocks chain on the
		// raw (not rendered) digest of their predecessor.
		parent, extra = digest, nullExtra
	}
	return hashes, nil
}
// encodeExtraKeys returns the CBOR encoding of the first block's extra keys:
// a one-element array holding mc.CacheSalt when it is non-empty, otherwise
// CBOR null.
func (b *builder) encodeExtraKeys(mc apis.ModelContext) ([]byte, error) {
	var keys any
	if mc.CacheSalt != "" {
		keys = []any{mc.CacheSalt}
	}
	return b.encMode.Marshal(keys)
}
// hashOneBlock returns the digest of one block's CBOR payload
// [parent, tokens, extra]: the 3-element array header (0x83) followed by the
// three encoded elements, hashed with the configured algorithm.
//
// parent is the previous block's raw digest and is encoded as a CBOR byte
// string. tokens is encoded as a CBOR array of integers. extraCBOR must
// already be exactly one encoded CBOR item; it is appended verbatim.
func (b *builder) hashOneBlock(parent []byte, tokens []int32, extraCBOR []byte) ([]byte, error) {
	// A nil slice would encode as CBOR null; the payload needs an (empty)
	// array, matching the previous []any-based encoding.
	if tokens == nil {
		tokens = []int32{}
	}
	// Encode the []int32 directly instead of boxing every token into an int64
	// inside a []any. The encoder writes each element as a shortest-form CBOR
	// integer either way, so the bytes are identical, but this avoids up to
	// one heap allocation per token (plus the []any itself) on the hot path.
	tokensCBOR, err := b.encMode.Marshal(tokens)
	if err != nil {
		return nil, fmt.Errorf("encode tokens: %w", err)
	}
	parentCBOR, err := b.encMode.Marshal(parent)
	if err != nil {
		return nil, fmt.Errorf("encode parent: %w", err)
	}
	buf := make([]byte, 0, 1+len(parentCBOR)+len(tokensCBOR)+len(extraCBOR))
	buf = append(buf, cborArrayLenThreeHeader)
	buf = append(buf, parentCBOR...)
	buf = append(buf, tokensCBOR...)
	buf = append(buf, extraCBOR...)
	return b.hash(buf), nil
}
// encodeBlockHash renders a raw digest as an apis.BlockHash hex string.
//
// By default the whole digest is hex-encoded. With cfg.EmitInt64 only the
// last 8 bytes are kept, giving a 16-character hex string. A digest shorter
// than 8 bytes is left-padded with zero bytes. The result is still hex, not
// a decimal integer.
func (b *builder) encodeBlockHash(raw []byte) apis.BlockHash {
	view := raw
	if b.cfg.EmitInt64 {
		var tail [uint64ByteWidth]byte
		if n := len(raw); n >= uint64ByteWidth {
			copy(tail[:], raw[n-uint64ByteWidth:])
		} else {
			copy(tail[uint64ByteWidth-n:], raw)
		}
		view = tail[:]
	}
	return apis.BlockHash(hex.EncodeToString(view))
}