* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package l3
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/go-logr/logr"
"gitcode.com/openFuyao/cache-indexer/pkg/apis"
indexl3 "gitcode.com/openFuyao/cache-indexer/pkg/index/l3"
)
// Shared fixtures for the poller tests in this file.
const (
	// Mooncake keys served by the fake /get_all_keys endpoint. The tests
	// look the corresponding blocks up as "aaaa", "bbbb" and "cccc".
	testMooncakeKeyA = "m@pcp0@AAAA"
	testMooncakeKeyB = "m@pcp0@BBBB"
	testMooncakeKeyC = "m@pcp0@CCCC"

	// Transport endpoints reported by the fake /query_key endpoint.
	testMooncakeEP1 = "ep-1:8000"
	testMooncakeEP2 = "ep-2:8000"

	// Body returned with a 200 status in TestFetchQueryKey_ObjectNotFound200.
	testObjectNotFoundBody = "OBJECT_NOT_FOUND"

	// Input and expected output for TestEscapeMooncakeKeyForQuery.
	testQueryEscapeRaw  = "host@a:b?c d+#%\t\r\n&"
	testQueryEscapeWant = "host@a:b%3fc%20d%2b%23%25%09%0d%0a%26"
)
// fakeDoer is an in-memory stand-in for the HTTP client used by the poller
// tests. Its Do method answers two paths, /get_all_keys and /query_key, and
// records how often and with which keys they were requested.
type fakeDoer struct {
	// mu guards allKeysSeq and batchKeysSeen, which Do mutates, and the
	// lookup in batchValues.
	mu sync.Mutex
	// allKeysSeq is a queue of key lists: every /get_all_keys request pops
	// and serves the first entry. Once it is drained, an empty list is served.
	allKeysSeq [][]string
	// batchValues maps a key to the rows served for /query_key?key=<key>.
	// Keys missing from the map are answered with a 404.
	batchValues map[string][]struct {
		TransportEndpoint string `json:"transport_endpoint_"`
		Protocol          string `json:"protocol_"`
	}
	// allKeysCalls counts /get_all_keys requests.
	allKeysCalls atomic.Int32
	// batchKeysCalls counts /query_key requests.
	batchKeysCalls atomic.Int32
	// batchKeysSeen holds one single-element slice per /query_key request,
	// in arrival order, containing the requested key.
	batchKeysSeen [][]string
}

// Do serves req from the scripted state.
//
//   - /get_all_keys returns 200 with the next key list, one key per line and
//     a trailing newline.
//   - /query_key returns 200 with one JSON object per line for the rows
//     registered under the "key" query parameter, or 404 with an empty body
//     when the key is unknown.
//   - Any other path returns 404 with an empty body.
//
// The returned error is non-nil only if a row cannot be JSON-encoded.
func (f *fakeDoer) Do(req *http.Request) (*http.Response, error) {
	switch req.URL.Path {
	case "/get_all_keys":
		f.allKeysCalls.Add(1)

		f.mu.Lock()
		var keys []string
		if len(f.allKeysSeq) > 0 {
			keys, f.allKeysSeq = f.allKeysSeq[0], f.allKeysSeq[1:]
		}
		f.mu.Unlock()

		body := strings.Join(keys, "\n") + "\n"
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       io.NopCloser(strings.NewReader(body)),
		}, nil

	case "/query_key":
		f.batchKeysCalls.Add(1)
		asked := req.URL.Query().Get("key")

		f.mu.Lock()
		f.batchKeysSeen = append(f.batchKeysSeen, []string{asked})
		rows, ok := f.batchValues[asked]
		f.mu.Unlock()

		if !ok {
			return &http.Response{
				StatusCode: http.StatusNotFound,
				Body:       io.NopCloser(bytes.NewReader(nil)),
			}, nil
		}

		// The row type already carries the wire-format JSON tags, so it is
		// encoded directly. Encode appends the newline separating rows.
		var b bytes.Buffer
		enc := json.NewEncoder(&b)
		for _, row := range rows {
			if err := enc.Encode(row); err != nil {
				return nil, err
			}
		}
		return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(&b)}, nil
	}
	return &http.Response{
		StatusCode: http.StatusNotFound,
		Body:       io.NopCloser(bytes.NewReader(nil)),
	}, nil
}
// responseSpec is one canned reply handed out by scriptedDoer.
type responseSpec struct {
	status int    // status code placed on the response
	body   string // response body
	err    error  // when non-nil, returned instead of a response
}

// scriptedDoer replays a fixed script of replies, keyed by request path.
type scriptedDoer struct {
	mu sync.Mutex
	// byPath holds, per URL path, the replies still to be served, in order.
	byPath map[string][]responseSpec
	// lastURL is the full URL of the most recent request, scripted or not.
	lastURL string
}

// Do records the request URL, then consumes and serves the next reply
// scripted for the request path. A request with no reply left for its path
// yields an "unexpected request" error.
func (s *scriptedDoer) Do(req *http.Request) (*http.Response, error) {
	path := req.URL.Path

	// Pick the reply while holding the lock; build the response afterwards.
	s.mu.Lock()
	s.lastURL = req.URL.String()
	pending := s.byPath[path]
	scripted := len(pending) > 0
	var next responseSpec
	if scripted {
		next = pending[0]
		s.byPath[path] = pending[1:]
	}
	s.mu.Unlock()

	if !scripted {
		return nil, errors.New("unexpected request: " + path)
	}
	if next.err != nil {
		return nil, next.err
	}
	resp := &http.Response{
		StatusCode: next.status,
		Body:       io.NopCloser(strings.NewReader(next.body)),
	}
	return resp, nil
}
// TestMooncakeKeyToBlockHash pins the package-level key-to-hash mapping.
// Per the fixtures below, the expected hash is the lower-cased text after
// the last '@', the key itself when it has no '@', and empty when nothing
// follows the last '@'.
func TestMooncakeKeyToBlockHash(t *testing.T) {
	type fixture struct {
		key  string
		hash string
	}
	fixtures := []fixture{
		{key: "Qwen3-8B@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@DEADBEEFDEADBEEF", hash: "deadbeefdeadbeef"},
		{key: "deadbeef", hash: "deadbeef"},
		{key: "", hash: ""},
		{key: "only@", hash: ""},
	}
	for i := range fixtures {
		fx := fixtures[i]
		if got := string(mooncakeKeyToBlockHash(fx.key)); got != fx.hash {
			t.Errorf("in=%q got=%q want=%q", fx.key, got, fx.hash)
		}
	}
}
// TestMooncakeKeyToBlockHash_EmitInt64Truncation checks the updater method
// under both EmitInt64 settings. With the flag set, a 64-digit hex suffix is
// expected to shrink to its last 16 digits while a short suffix is only
// lower-cased; with the flag unset the full suffix is expected back.
func TestMooncakeKeyToBlockHash_EmitInt64Truncation(t *testing.T) {
	const (
		fullHex = "198bcba7945a909d278ff8b13c418bea79aecf30da51c0663f061826a245be2e"
		// wantInt is the trailing 16 hex digits of fullHex.
		wantInt = "3f061826a245be2e"
	)
	scenarios := []struct {
		label     string
		emitInt64 bool
		key       string
		want      string
	}{
		{
			label:     "EmitInt64=true",
			emitInt64: true,
			key:       "Qwen3-8B@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@" + fullHex,
			want:      wantInt,
		},
		{
			label:     "EmitInt64=false",
			emitInt64: false,
			key:       "Qwen3-8B@pcp0@" + fullHex,
			want:      fullHex,
		},
		{
			label:     "short suffix",
			emitInt64: true,
			key:       "m@pcp0@DEAD",
			want:      "dead",
		},
	}
	for _, sc := range scenarios {
		// Each scenario gets its own updater, as no state is meant to be shared.
		u := &updater{cfg: PollerConfig{EmitInt64: sc.emitInt64}.withDefaults()}
		if got := string(u.mooncakeKeyToBlockHash(sc.key)); got != sc.want {
			t.Errorf("%s got=%q want=%q", sc.label, got, sc.want)
		}
	}
}
func TestPollOnce_RefreshAndRemoveDiff(t *testing.T) {
idx := indexl3.New(logr.Discard())
doer := &fakeDoer{
allKeysSeq: [][]string{
{testMooncakeKeyA, testMooncakeKeyB},
{testMooncakeKeyB, testMooncakeKeyC},
},
batchValues: map[string][]struct {
TransportEndpoint string `json:"transport_endpoint_"`
Protocol string `json:"protocol_"`
}{
testMooncakeKeyA: {{TransportEndpoint: testMooncakeEP1}},
testMooncakeKeyB: {{TransportEndpoint: testMooncakeEP1}},
testMooncakeKeyC: {{TransportEndpoint: testMooncakeEP2}},
},
}
u := &updater{
log: logr.Discard(), cfg: PollerConfig{}.withDefaults(),
indexer: idx, client: doer,
keyOwnership: map[string][]apis.TransportEndpoint{},
}
prev := map[string]struct{}{}
ctx := context.Background()
prev = u.pollOnce(ctx, "http://x", prev)
res, _ := idx.MatchPrefix(apis.ModelContext{}, "ep-1:8000",
[]apis.BlockHash{"aaaa", "bbbb"})
if res.MatchedBlocks != 2 {
t.Fatalf("round1 ep-1 matched=%d want 2", res.MatchedBlocks)
}
prev = u.pollOnce(ctx, "http://x", prev)
res, _ = idx.MatchPrefix(apis.ModelContext{}, "ep-1:8000",
[]apis.BlockHash{"aaaa"})
if res.MatchedBlocks != 0 {
t.Errorf("AAAA should be removed, got %d", res.MatchedBlocks)
}
res, _ = idx.MatchPrefix(apis.ModelContext{}, "ep-1:8000",
[]apis.BlockHash{"bbbb"})
if res.MatchedBlocks != 1 {
t.Errorf("BBBB should remain, got %d", res.MatchedBlocks)
}
res, _ = idx.MatchPrefix(apis.ModelContext{}, "ep-2:8000",
[]apis.BlockHash{"cccc"})
if res.MatchedBlocks != 1 {
t.Errorf("CCCC should be on ep-2, got %d", res.MatchedBlocks)
}
if got := doer.batchKeysCalls.Load(); got != 3 {
t.Errorf("query_key calls = %d want 3", got)
}
if n := len(doer.batchKeysSeen); n == 0 || len(doer.batchKeysSeen[n-1]) != 1 ||
doer.batchKeysSeen[n-1][0] != testMooncakeKeyC {
t.Errorf("last refresh entry = %v want [%s]", doer.batchKeysSeen, testMooncakeKeyC)
}
}
func TestPollOnce_SkipsRowsWithoutTransportEndpoint(t *testing.T) {
idx := indexl3.New(logr.Discard())
doer := &fakeDoer{
allKeysSeq: [][]string{{"m@AAAA"}},
batchValues: map[string][]struct {
TransportEndpoint string `json:"transport_endpoint_"`
Protocol string `json:"protocol_"`
}{
"m@AAAA": {
{TransportEndpoint: ""},
{TransportEndpoint: "ep:1"},
},
},
}
u := &updater{log: logr.Discard(), cfg: PollerConfig{}.withDefaults(),
indexer: idx, client: doer, keyOwnership: map[string][]apis.TransportEndpoint{}}
u.pollOnce(context.Background(), "http://x", map[string]struct{}{})
res, _ := idx.MatchPrefix(apis.ModelContext{}, "ep:1", []apis.BlockHash{"aaaa"})
if res.MatchedBlocks != 1 {
t.Errorf("ep:1 should hold AAAA, got %d", res.MatchedBlocks)
}
}
func TestPollOnce_NilPrevInitializesCarryOver(t *testing.T) {
idx := indexl3.New(logr.Discard())
doer := &fakeDoer{
allKeysSeq: [][]string{{"m@AAAA"}},
batchValues: map[string][]struct {
TransportEndpoint string `json:"transport_endpoint_"`
Protocol string `json:"protocol_"`
}{
"m@AAAA": {{TransportEndpoint: "ep:1"}},
},
}
u := &updater{log: logr.Discard(), cfg: PollerConfig{}.withDefaults(),
indexer: idx, client: doer, keyOwnership: map[string][]apis.TransportEndpoint{}}
prev := u.pollOnce(context.Background(), "http://x", nil)
if prev == nil {
t.Fatal("expected pollOnce to initialize prev")
}
if _, ok := prev["m@AAAA"]; !ok {
t.Fatalf("expected carry-over to contain refreshed key, got %v", prev)
}
res, _ := idx.MatchPrefix(apis.ModelContext{}, "ep:1", []apis.BlockHash{"aaaa"})
if res.MatchedBlocks != 1 {
t.Fatalf("expected refreshed key to be ingested, got %d", res.MatchedBlocks)
}
}
// TestEscapeMooncakeKeyForQuery compares the escaped form of the shared raw
// fixture against the expected string. The fixture leaves '@' and ':' as-is
// and expects lower-case percent escapes for the remaining special bytes.
func TestEscapeMooncakeKeyForQuery(t *testing.T) {
	if escaped := escapeMooncakeKeyForQuery(testQueryEscapeRaw); escaped != testQueryEscapeWant {
		t.Fatalf("got=%q want=%q", escaped, testQueryEscapeWant)
	}
}
func TestFetchQueryKey_ObjectNotFound200(t *testing.T) {
doer := &scriptedDoer{
byPath: map[string][]responseSpec{
"/query_key": {{status: http.StatusOK, body: testObjectNotFoundBody}},
},
}
u := &updater{log: logr.Discard(), cfg: PollerConfig{}.withDefaults(), client: doer}
rows, err := u.fetchQueryKey(context.Background(), "http://x", testMooncakeKeyA)
if err != nil {
t.Fatalf("fetchQueryKey: %v", err)
}
if rows != nil {
t.Fatalf("rows=%v want nil", rows)
}
if !strings.Contains(doer.lastURL, escapeMooncakeKeyForQuery(testMooncakeKeyA)) {
t.Fatalf("lastURL=%q missing escaped key", doer.lastURL)
}
}
func TestPollOnce_FetchAllKeysErrorPreservesPrev(t *testing.T) {
prev := map[string]struct{}{testMooncakeKeyA: {}}
doer := &scriptedDoer{
byPath: map[string][]responseSpec{
"/get_all_keys": {{err: errors.New("boom")}},
},
}
u := &updater{
log: logr.Discard(),
cfg: PollerConfig{}.withDefaults(),
indexer: indexl3.New(logr.Discard()),
client: doer,
keyOwnership: map[string][]apis.TransportEndpoint{testMooncakeKeyA: {testMooncakeEP1}},
}
got := u.pollOnce(context.Background(), "http://x", prev)
if len(got) != 1 {
t.Fatalf("prev mutated: %v", got)
}
if _, ok := got[testMooncakeKeyA]; !ok {
t.Fatalf("expected prev to keep %q", testMooncakeKeyA)
}
if len(u.keyOwnership[testMooncakeKeyA]) != 1 {
t.Fatalf("keyOwnership mutated: %+v", u.keyOwnership)
}
}
// TestUpdateTarget_StopsPreviousLoop starts polling against a valid target,
// clears the target with UpdateTarget(nil) and expects polling to cease, then
// expects UpdateTarget to fail once the updater has been stopped.
//
// The test is timing based: with a 50ms PollInterval it sleeps 120ms and
// requires at least one /get_all_keys call in that window.
func TestUpdateTarget_StopsPreviousLoop(t *testing.T) {
	idx := indexl3.New(logr.Discard())
	doer := &fakeDoer{allKeysSeq: nil, batchValues: nil}
	u := NewWithConfig(logr.Discard(), idx, PollerConfig{PollInterval: 50 * time.Millisecond}).(*updater)
	u.client = doer

	// t.Fatal below aborts the test before the explicit Stop is reached, which
	// would leave the poll loop running. The guard stops the updater on those
	// early exits without calling Stop a second time on the normal path.
	stopped := false
	defer func() {
		if !stopped {
			u.Stop()
		}
	}()

	if err := u.UpdateTarget(&MooncakePollTarget{PodIP: "1.2.3.4", HTTPPort: 8000}); err != nil {
		t.Fatal(err)
	}
	time.Sleep(120 * time.Millisecond)
	first := doer.allKeysCalls.Load()
	if first == 0 {
		t.Fatal("expected at least one /get_all_keys call before swap")
	}

	if err := u.UpdateTarget(nil); err != nil {
		t.Fatal(err)
	}
	time.Sleep(120 * time.Millisecond)
	after := doer.allKeysCalls.Load()
	// One extra call is tolerated, presumably for a poll already in flight
	// when the target was cleared.
	if after-first > 1 {
		t.Errorf("loop kept polling after nil target: first=%d after=%d", first, after)
	}

	stopped = true
	u.Stop()
	if err := u.UpdateTarget(&MooncakePollTarget{PodIP: "1.2.3.4", HTTPPort: 8000}); err == nil {
		t.Error("UpdateTarget after Stop should error")
	}
}
// TestUpdateTarget_InvalidTargetIsNoOp hands UpdateTarget a target with an
// empty PodIP and a zero HTTPPort. It expects no error and, after waiting
// longer than one 50ms PollInterval, no /get_all_keys request.
func TestUpdateTarget_InvalidTargetIsNoOp(t *testing.T) {
	idx := indexl3.New(logr.Discard())
	doer := &fakeDoer{}
	u := NewWithConfig(logr.Discard(), idx, PollerConfig{PollInterval: 50 * time.Millisecond}).(*updater)
	u.client = doer
	// Deferred so the updater is stopped even when t.Fatal aborts the test.
	defer u.Stop()

	if err := u.UpdateTarget(&MooncakePollTarget{PodIP: "", HTTPPort: 0}); err != nil {
		t.Fatal(err)
	}
	time.Sleep(80 * time.Millisecond)
	if doer.allKeysCalls.Load() != 0 {
		t.Errorf("invalid target must not trigger HTTP calls")
	}
}