* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package server
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/go-logr/logr"
"gitcode.com/openFuyao/cache-indexer/pkg/apis"
"gitcode.com/openFuyao/cache-indexer/pkg/config"
"gitcode.com/openFuyao/cache-indexer/pkg/score"
)
// hitRateRequest is the top-level JSON payload decoded by the hit-rate
// endpoint.
type hitRateRequest struct {
// ServerIP lists the servers handed to the score service. An empty list is
// rejected by validateHitRateRequest.
ServerIP []string `json:"server_ip"`
// Body carries the token IDs and the model parameters used for scoring.
Body hitRateRequestBody `json:"body"`
}
// hitRateRequestBody is the nested "body" object of a hitRateRequest.
type hitRateRequestBody struct {
// TokenIDs is forwarded unchanged to the score service.
TokenIDs []int32 `json:"token_ids"`
// CacheSalt is copied into apis.ModelContext by requestModelContext.
CacheSalt string `json:"cache_salt"`
// BlockSize is copied into apis.ModelContext; validateHitRateRequest
// requires it to be greater than zero.
BlockSize int64 `json:"block_size"`
}
// hitRateResponse is the JSON payload returned for a successful hit-rate
// request.
type hitRateResponse struct {
// ServerScoreList holds one entry per score returned by the score service.
ServerScoreList []serverScoreDTO `json:"server_score_list"`
// Message and Status are not populated by buildHitRateResponse, so they are
// emitted with their zero values ("" and 0) on success.
Message string `json:"message"`
Status int `json:"status"`
}
// serverScoreDTO is the wire representation of a single apis.ServerScore.
type serverScoreDTO struct {
// ServerIP identifies the server the ratios belong to.
ServerIP string `json:"server_ip"`
// L1HitRatio and L3HitRatio are copied verbatim from apis.ServerScore.
// NOTE(review): presumably hit ratios of two cache tiers — confirm the exact
// meaning and value range against the score package.
L1HitRatio float64 `json:"l1_hit_ratio"`
L3HitRatio float64 `json:"l3_hit_ratio"`
}
// errorResponse is the JSON payload written for every rejected or failed
// request.
type errorResponse struct {
// Message is a short description of why the request was not served.
Message string `json:"message"`
// Status repeats the HTTP status code of the response; every call site in
// this file sets it to the same code passed to writeJSON.
Status int `json:"status"`
}
// healthResponse is the JSON payload served by the /healthz and /readyz
// endpoints.
type healthResponse struct {
// Status is always "ok" for the probes registered in newMux.
Status string `json:"status"`
}
// newMux builds the HTTP router for the service.
//
// It registers two probe endpoints (/healthz and /readyz) that always answer
// 200 with {"status":"ok"}, and the /kv-cache/hit-rate endpoint, which is
// served by hitRateHandler with a logger named "hit-rate" and the body size
// limit taken from cfg.MaxHitRateBodyBytes.
func newMux(log logr.Logger, scoreSvc score.Service, cfg config.HTTPConfig) *http.ServeMux {
	router := http.NewServeMux()

	// Both probes share one handler: neither inspects the request nor any
	// internal state, they simply report "ok".
	probe := func(w http.ResponseWriter, _ *http.Request) {
		writeJSON(log, w, http.StatusOK, healthResponse{Status: "ok"})
	}
	for _, path := range []string{"/healthz", "/readyz"} {
		router.HandleFunc(path, probe)
	}

	hitRate := hitRateHandler(log.WithName("hit-rate"), scoreSvc, cfg.MaxHitRateBodyBytes)
	router.Handle("/kv-cache/hit-rate", hitRate)

	return router
}
// hitRateHandler adapts handleHitRate to an http.HandlerFunc by binding the
// logger, the score service and the request body size limit.
func hitRateHandler(log logr.Logger, scoreSvc score.Service, maxBodyBytes int64) http.HandlerFunc {
	serve := func(rw http.ResponseWriter, req *http.Request) {
		handleHitRate(log, scoreSvc, maxBodyBytes, rw, req)
	}
	return http.HandlerFunc(serve)
}
// handleHitRate serves a single hit-rate request end to end.
//
// The pipeline is: method check, body decoding, request validation, score
// computation, response construction. Each stage writes its own error
// response on failure, so this function only has to stop at the first stage
// that reports a problem.
func handleHitRate(
	log logr.Logger,
	scoreSvc score.Service,
	maxBodyBytes int64,
	w http.ResponseWriter,
	r *http.Request,
) {
	if !allowPostMethod(log, w, r) {
		return
	}
	log.V(1).Info("hit-rate request received", "remoteAddr", r.RemoteAddr)

	req, decoded := decodeHitRateRequest(log, w, r, maxBodyBytes)
	if !decoded {
		return
	}
	if !validateHitRateRequest(log, w, req) {
		return
	}

	scores, computed := computeHitRateScores(log, scoreSvc, w, r, req)
	if !computed {
		return
	}

	resp := buildHitRateResponse(scores)
	logHitRateCompletion(log, req, resp)
	writeJSON(log, w, http.StatusOK, resp)
}
// allowPostMethod reports whether the request uses the POST method.
//
// For any other method it writes a 405 JSON error response and returns false,
// in which case the caller must not write anything further. The response
// carries an "Allow: POST" header, which RFC 9110 requires on every 405 so the
// client can discover the supported method.
func allowPostMethod(log logr.Logger, w http.ResponseWriter, r *http.Request) bool {
	if r.Method == http.MethodPost {
		return true
	}
	// Must be set before writeJSON, which sends the response headers.
	w.Header().Set("Allow", http.MethodPost)
	writeJSON(log, w, http.StatusMethodNotAllowed, errorResponse{
		Message: "method not allowed",
		Status:  http.StatusMethodNotAllowed,
	})
	return false
}
// decodeHitRateRequest parses the request body into a hitRateRequest.
//
// The body is read through http.MaxBytesReader capped at maxBodyBytes and is
// decoded strictly: unknown fields are rejected, and so is anything other than
// whitespace following the first JSON value.
//
// On success it returns the decoded request and true. On failure it logs the
// reason, writes the error response itself and returns a zero request and
// false; the caller must not write to w afterwards. The responses are:
//   - 413 when the body exceeds maxBodyBytes,
//   - 400 "empty request body" when the body contains no JSON value,
//   - 400 "invalid JSON: ..." for every other decoding problem.
func decodeHitRateRequest(
	log logr.Logger,
	w http.ResponseWriter,
	r *http.Request,
	maxBodyBytes int64,
) (hitRateRequest, bool) {
	dec := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	dec.DisallowUnknownFields()

	var req hitRateRequest
	err := dec.Decode(&req)
	if err == nil {
		// Decode consumes only the first JSON value. Probe for a second one so
		// that payloads such as `{...}{...}` or `{...} junk` are rejected
		// instead of having their tail silently ignored.
		var trailing json.RawMessage
		trailErr := dec.Decode(&trailing)
		switch {
		case errors.Is(trailErr, io.EOF):
			// Nothing but optional whitespace after the value: well-formed.
			return req, true
		case trailErr == nil:
			err = errors.New("unexpected data after top-level JSON value")
		default:
			// Malformed trailing bytes, or the size cap was hit while
			// reading them; classified below like any other decode error.
			err = trailErr
		}
	}

	var maxErr *http.MaxBytesError
	switch {
	case errors.As(err, &maxErr):
		log.Info("hit-rate request rejected: body too large", "error", err.Error())
		writeJSON(log, w, http.StatusRequestEntityTooLarge, errorResponse{
			Message: "request body too large",
			Status:  http.StatusRequestEntityTooLarge,
		})
	case errors.Is(err, io.EOF):
		// A bare io.EOF from the first Decode means no JSON value was present.
		log.Info("hit-rate request rejected: empty body", "error", err.Error())
		writeJSON(log, w, http.StatusBadRequest, errorResponse{
			Message: "empty request body",
			Status:  http.StatusBadRequest,
		})
	default:
		log.Info("hit-rate request rejected: invalid JSON", "error", err.Error())
		writeJSON(log, w, http.StatusBadRequest, errorResponse{
			Message: "invalid JSON: " + err.Error(),
			Status:  http.StatusBadRequest,
		})
	}
	return hitRateRequest{}, false
}
// validateHitRateRequest checks the semantic constraints of a decoded request.
//
// It requires at least one entry in server_ip and a strictly positive
// body.block_size. When a constraint is violated it logs the rejection, writes
// a 400 JSON error response and returns false; otherwise it returns true
// without touching w.
func validateHitRateRequest(log logr.Logger, w http.ResponseWriter, req hitRateRequest) bool {
	if len(req.ServerIP) == 0 {
		log.Info("hit-rate request rejected: server_ip is required")
		writeJSON(log, w, http.StatusBadRequest, errorResponse{
			Message: "server_ip is required",
			Status:  http.StatusBadRequest,
		})
		return false
	}

	if blockSize := req.Body.BlockSize; blockSize <= 0 {
		log.Info("hit-rate request rejected: body.block_size must be > 0",
			"blockSize", blockSize)
		writeJSON(log, w, http.StatusBadRequest, errorResponse{
			Message: "body.block_size must be > 0",
			Status:  http.StatusBadRequest,
		})
		return false
	}

	return true
}
// computeHitRateScores asks the score service to evaluate the request.
//
// The request context is passed through to scoreSvc.Compute together with the
// model context derived from the request body, the server list and the token
// IDs. On success it returns the scores and true. On error it logs the
// failure, writes a 500 JSON error response and returns (nil, false).
func computeHitRateScores(
	log logr.Logger,
	scoreSvc score.Service,
	w http.ResponseWriter,
	r *http.Request,
	req hitRateRequest,
) ([]apis.ServerScore, bool) {
	modelCtx := requestModelContext(req.Body)

	result, computeErr := scoreSvc.Compute(r.Context(), modelCtx, req.ServerIP, req.Body.TokenIDs)
	if computeErr != nil {
		log.Error(computeErr, "hit-rate request failed: compute",
			"servers", len(req.ServerIP), "tokenIDs", len(req.Body.TokenIDs))
		// The underlying error is logged only; the client gets a generic message.
		writeJSON(log, w, http.StatusInternalServerError, errorResponse{
			Message: "score evaluation failed",
			Status:  http.StatusInternalServerError,
		})
		return nil, false
	}

	return result, true
}
// requestModelContext converts the model-related fields of the request body
// (cache salt and block size) into an apis.ModelContext. All other fields of
// the result keep their zero values.
func requestModelContext(body hitRateRequestBody) apis.ModelContext {
	var mc apis.ModelContext
	mc.CacheSalt = body.CacheSalt
	mc.BlockSize = body.BlockSize
	return mc
}
// buildHitRateResponse maps the service-level scores onto the wire format,
// preserving their order.
//
// The resulting list is always non-nil, so an empty input is encoded as the
// JSON array [] rather than null. Message and Status are left at their zero
// values.
func buildHitRateResponse(scores []apis.ServerScore) hitRateResponse {
	list := make([]serverScoreDTO, 0, len(scores))
	for idx := range scores {
		src := scores[idx]
		list = append(list, serverScoreDTO{
			ServerIP:   src.ServerIP,
			L1HitRatio: src.L1HitRatio,
			L3HitRatio: src.L3HitRatio,
		})
	}

	var resp hitRateResponse
	resp.ServerScoreList = list
	return resp
}
// logHitRateCompletion emits the completion log line for a served request.
//
// The entry records the number of servers and token IDs in the request along
// with the JSON-encoded response body. If the response cannot be marshalled
// for logging, an error entry with the same counters is emitted instead.
func logHitRateCompletion(log logr.Logger, req hitRateRequest, respBody hitRateResponse) {
	serverCount := len(req.ServerIP)
	tokenCount := len(req.Body.TokenIDs)

	encoded, marshalErr := json.Marshal(respBody)
	if marshalErr == nil {
		log.Info("hit-rate request completed",
			"servers", serverCount,
			"tokenIDs", tokenCount,
			"responseBody", string(encoded))
		return
	}

	log.Error(marshalErr, "hit-rate response marshal failed for logging",
		"servers", serverCount, "tokenIDs", tokenCount)
}
// writeJSON sends body as a JSON document with the given HTTP status code.
//
// The payload is marshalled before any header is sent. If marshalling fails
// (for example because body contains a NaN or infinite float, which
// encoding/json rejects), the client receives a 500 with a fixed JSON error
// body instead of the requested status followed by an empty body. On success
// the bytes on the wire match json.Encoder.Encode: the document followed by a
// single newline.
//
// Failures are logged and not returned, because nothing more can be done for
// the client once the response has been started.
func writeJSON(log logr.Logger, w http.ResponseWriter, code int, body any) {
	w.Header().Set("Content-Type", "application/json")

	payload, err := json.Marshal(body)
	if err != nil {
		log.Error(err, "encode JSON response failed", "status", code)
		// Static fallback so that this path cannot itself fail to encode.
		const fallback = `{"message":"internal server error","status":500}` + "\n"
		w.WriteHeader(http.StatusInternalServerError)
		if _, writeErr := w.Write([]byte(fallback)); writeErr != nil {
			log.Error(writeErr, "write JSON response failed", "status", http.StatusInternalServerError)
		}
		return
	}

	w.WriteHeader(code)
	if _, err := w.Write(append(payload, '\n')); err != nil {
		log.Error(err, "write JSON response failed", "status", code)
	}
}