/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package lifecycle
import (
"bytes"
"context"
"encoding/json"
"fmt"
"maps"
"reflect"
"time"
"github.com/go-logr/logr"
"sigs.k8s.io/controller-runtime/pkg/log"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
fwkplugin "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
requestcontrol "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
fwkscheduling "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
latencyattr "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/datalayer/attribute/latency"
npuattr "hermes-router/pkg/epp/framework/plugins/datalayer/attribute/npu"
prerequest "hermes-router/pkg/epp/framework/plugins/requestcontrol/prerequest"
internalinflight "hermes-router/pkg/epp/internal/inflight"
"hermes-router/pkg/epp/internal/pdgroup"
"hermes-router/pkg/epp/internal/score"
"hermes-router/pkg/epp/internal/utils"
)
const (
	// PluginType is the type name under which the request lifecycle tracker
	// plugin is identified (it is the Type of every tracker's TypedName).
	PluginType = "request-lifecycle-tracker"
)

// clock yields the current time. The tracker stores one as a field
// (NewRequestLifecycleTracker installs time.Now) so an alternative time source
// can be injected.
type clock func() time.Time
// Plugin configuration types. The raw* variants mirror the JSON accepted by
// Factory; parseParameters decodes into them, validates the result, and
// converts it into pluginParameters.
type (
	// pluginParameters is the validated tracker configuration.
	pluginParameters struct {
		StoreName   string                `json:"storeName"`
		Persistence persistenceParameters `json:"persistence"`
	}

	// persistenceParameters is the validated persistence configuration.
	persistenceParameters struct {
		Enabled        bool `json:"enabled"`
		FlushThreshold int  `json:"flushThreshold"`
		// FlushInterval carries no JSON tag: parseParameters fills it from the
		// string form in rawPersistenceParameters.FlushInterval.
		FlushInterval time.Duration
		OutputPath    string `json:"outputPath"`
	}

	// rawPluginParameters is the JSON shape of the plugin parameters.
	rawPluginParameters struct {
		StoreName   string                   `json:"storeName"`
		Persistence rawPersistenceParameters `json:"persistence"`
	}

	// rawPersistenceParameters is the JSON shape of the persistence section;
	// FlushInterval is a duration string such as "10s".
	rawPersistenceParameters struct {
		Enabled        bool   `json:"enabled"`
		FlushThreshold int    `json:"flushThreshold"`
		FlushInterval  string `json:"flushInterval"`
		OutputPath     string `json:"outputPath"`
	}
)
// RequestLifecycleTracker is a plugin that records request lifecycle events
// (scheduling/dispatch in PreRequest, response chunks and completion in
// ResponseBody) into an internalinflight.Store, and forwards data-layer
// endpoint events to that same store.
type RequestLifecycleTracker struct {
	// typedName is the plugin's type (PluginType) and instance name.
	typedName fwkplugin.TypedName
	// store receives all lifecycle updates; a nil store disables tracking.
	store *internalinflight.Store
	// now timestamps recorded events.
	now clock
}

// Compile-time checks that the tracker satisfies every plugin interface it is
// used through.
var (
	_ requestcontrol.PreRequest   = (*RequestLifecycleTracker)(nil)
	_ requestcontrol.ResponseBody = (*RequestLifecycleTracker)(nil)
	_ fwkdl.EndpointExtractor     = (*RequestLifecycleTracker)(nil)
)
// Factory constructs a RequestLifecycleTracker plugin from its JSON parameters.
//
// It obtains the store named by params.StoreName through
// internalinflight.StoreForNameWithPersistence with the configured persistence
// settings. It attaches a logger derived from the handle's context, and
// registers a goroutine that closes the store once the handle's context is
// done. Parameter errors are wrapped with the plugin name.
func Factory(name string, raw json.RawMessage, handle fwkplugin.Handle) (fwkplugin.Plugin, error) {
	params, err := parseParameters(raw)
	if err != nil {
		return nil, fmt.Errorf("request lifecycle tracker %q: %w", name, err)
	}

	persistence := params.Persistence
	persistenceConfig := internalinflight.PersistenceConfig{
		Enabled:        persistence.Enabled,
		FlushThreshold: persistence.FlushThreshold,
		FlushInterval:  persistence.FlushInterval,
		OutputPath:     persistence.OutputPath,
	}
	store := internalinflight.StoreForNameWithPersistence(params.StoreName, persistenceConfig)

	store.ConfigureLogger(storeLogger(handle, name, params.StoreName))
	registerShutdownClose(handle, store)

	return NewRequestLifecycleTracker(name, store), nil
}
// NewRequestLifecycleTracker returns a tracker with instance name name that
// records into store and timestamps events with time.Now.
func NewRequestLifecycleTracker(name string, store *internalinflight.Store) *RequestLifecycleTracker {
	tracker := &RequestLifecycleTracker{store: store, now: time.Now}
	tracker.typedName = fwkplugin.TypedName{Type: PluginType, Name: name}
	return tracker
}
// TypedName reports the plugin's type and instance name.
func (t *RequestLifecycleTracker) TypedName() fwkplugin.TypedName {
	name := t.typedName
	return name
}

// ExpectedInputType declares that the tracker consumes data-layer endpoint
// events.
func (*RequestLifecycleTracker) ExpectedInputType() reflect.Type {
	return fwkdl.EndpointEventReflectType
}

// Extract is a no-op that always returns nil. Endpoint events are handled in
// ExtractEndpoint instead; presumably the data layer routes them there for
// this input type (confirm against fwkdl.EndpointExtractor).
func (*RequestLifecycleTracker) Extract(context.Context, any, fwkdl.Endpoint) error {
	return nil
}
// PreRequest records the start of a request's lifecycle after scheduling.
//
// The selected endpoint is the first target of the primary scheduling profile.
// PreRequest passes the following to Store.StartRequest:
//   - the request ID, target model and input length (see requestInputTokens);
//   - the prefill/decode endpoints when the primary route is a complete
//     prefill/decode route;
//   - the store's inflight snapshot for the selected endpoint at dispatch;
//   - the prefix-cache score and latency prediction read from endpoint
//     attributes;
//   - a per-endpoint snapshot of the selected route;
//   - the prediction inputs and missing-feature reasons;
//   - the start timestamp.
//
// Nothing is recorded (a debug line is logged instead) when the tracker or
// its store is nil, the request or its ID is missing, or no endpoint with
// metadata was selected.
func (t *RequestLifecycleTracker) PreRequest(
	ctx context.Context,
	request *fwkscheduling.LLMRequest,
	schedulingResult *fwkscheduling.SchedulingResult,
) {
	// The named logger is only needed when tracking is skipped, so build it
	// lazily instead of paying for FromContext+WithName on every request.
	logSkip := func(msg string) {
		log.FromContext(ctx).WithName("RequestLifecycleTracker.PreRequest").V(logutil.DEBUG).Info(msg)
	}

	if t == nil || t.store == nil || request == nil || request.RequestId == "" {
		logSkip("Skipping request lifecycle tracking due to missing request or request ID")
		return
	}
	route := prerequest.ExtractPrimaryRoute(ctx, schedulingResult)
	selected := selectedTarget(schedulingResult)
	if selected == nil || selected.GetMetadata() == nil {
		logSkip("Skipping request lifecycle tracking because no selected endpoint is available")
		return
	}
	selectedMetadata := selected.GetMetadata()

	trackedEndpoints := trackedRouteEndpoints(route)
	predictionInputs, missingFeatureReasons := selectedPredictionInputs(selected)
	inputTokenLength := requestInputTokens(request)
	cacheScore, prediction, predictionMode := selectedRoutingSnapshot(selected)
	t.store.StartRequest(internalinflight.RequestStart{
		RequestID:        request.RequestId,
		Model:            request.TargetModel,
		InputTokenLength: inputTokenLength,
		SelectedEndpoint: selectedMetadata,
		TrackedEndpoints: trackedEndpoints,
		// Inflight load on the selected endpoint as seen just before this
		// request is registered.
		DispatchInflight:      t.store.Snapshot(selectedMetadata),
		PrefixCacheScore:      cacheScore,
		Prediction:            prediction,
		PredictionMode:        predictionMode,
		SelectedRoute:         selectedRouteSnapshot(t.store, selected),
		PredictionInputs:      predictionInputs,
		MissingFeatureReasons: missingFeatureReasons,
		StartedAt:             t.now(),
	})
}
// ResponseBody records response progress for a tracked request.
//
// Every tracked call reports a first-chunk timestamp to the store. When
// response.EndOfStream is set, it also completes the request with the
// completion-token count (see outputLength) and the target endpoint. The
// EndOfStream check suggests this runs once per response chunk. Nothing is
// recorded (a debug line is logged) when the tracker, its store, the request,
// its ID or the response is missing.
func (t *RequestLifecycleTracker) ResponseBody(
	ctx context.Context,
	request *fwkscheduling.LLMRequest,
	response *requestcontrol.Response,
	targetEndpoint *fwkdl.EndpointMetadata,
) {
	if !canTrackResponse(t, request, response) {
		// Build the logger only here: constructing a named logger up front
		// would cost an allocation on every chunk of every tracked response.
		log.FromContext(ctx).WithName("RequestLifecycleTracker.ResponseBody").V(logutil.DEBUG).
			Info("Skipping response lifecycle tracking due to missing request, request ID, or response")
		return
	}
	// MarkFirstChunk is invoked for every chunk. Presumably the store keeps
	// only the earliest timestamp — confirm in internalinflight.
	t.store.MarkFirstChunk(request.RequestId, t.now())
	if !response.EndOfStream {
		return
	}
	t.store.CompleteRequest(internalinflight.RequestCompletion{
		RequestID:        request.RequestId,
		CompletedAt:      t.now(),
		OutputLength:     outputLength(response),
		SelectedEndpoint: targetEndpoint,
	})
}
// ExtractEndpoint forwards a data-layer endpoint event to the store, stamped
// with the tracker's clock. A nil tracker or store makes it a no-op; it always
// returns nil.
func (t *RequestLifecycleTracker) ExtractEndpoint(_ context.Context, event fwkdl.EndpointEvent) error {
	if t != nil && t.store != nil {
		t.store.ApplyEndpointEvent(event, t.now())
	}
	return nil
}
// parseParameters decodes and validates the plugin's JSON parameters.
//
// An empty or whitespace-only raw message yields the defaults: the default
// store name with persistence disabled. An empty storeName also falls back to
// the default. Decoding is strict: unknown fields and any data after the
// top-level JSON value are rejected. persistence.flushInterval is parsed with
// time.ParseDuration (e.g. "500ms", "10s").
//
// The returned error is prefixed with "invalid parameters:" when:
//   - the JSON is malformed, has unknown fields, or has trailing data;
//   - persistence.flushThreshold is negative;
//   - persistence.flushInterval is not a valid, non-negative duration;
//   - persistence is enabled without persistence.outputPath.
func parseParameters(raw json.RawMessage) (pluginParameters, error) {
	if len(bytes.TrimSpace(raw)) == 0 {
		return pluginParameters{StoreName: internalinflight.DefaultStoreName}, nil
	}

	decoded := rawPluginParameters{StoreName: internalinflight.DefaultStoreName}
	decoder := json.NewDecoder(bytes.NewReader(raw))
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(&decoded); err != nil {
		return pluginParameters{}, fmt.Errorf("invalid parameters: %w", err)
	}
	// Decode consumes only the first JSON value. Without this check, anything
	// after it would be silently ignored, defeating the strict decoding above.
	if trailing := bytes.TrimSpace(raw[decoder.InputOffset():]); len(trailing) > 0 {
		return pluginParameters{}, fmt.Errorf("invalid parameters: unexpected data after the top-level JSON value")
	}

	params := pluginParameters{
		StoreName: decoded.StoreName,
		Persistence: persistenceParameters{
			Enabled:        decoded.Persistence.Enabled,
			FlushThreshold: decoded.Persistence.FlushThreshold,
			OutputPath:     decoded.Persistence.OutputPath,
		},
	}
	if params.StoreName == "" {
		params.StoreName = internalinflight.DefaultStoreName
	}
	if params.Persistence.FlushThreshold < 0 {
		return pluginParameters{}, fmt.Errorf("invalid parameters: persistence.flushThreshold must be >= 0")
	}
	if decoded.Persistence.FlushInterval != "" {
		flushInterval, err := time.ParseDuration(decoded.Persistence.FlushInterval)
		if err != nil {
			return pluginParameters{}, fmt.Errorf(
				"invalid parameters: persistence.flushInterval must be a valid duration: %w",
				err,
			)
		}
		if flushInterval < 0 {
			return pluginParameters{}, fmt.Errorf("invalid parameters: persistence.flushInterval must be >= 0")
		}
		params.Persistence.FlushInterval = flushInterval
	}
	if params.Persistence.Enabled && params.Persistence.OutputPath == "" {
		return pluginParameters{}, fmt.Errorf(
			"invalid parameters: persistence.outputPath is required when persistence is enabled",
		)
	}
	return params, nil
}
// requestInputTokens returns the request's input length. It uses the tokenized
// prompt's token count when present and non-empty. Otherwise it falls back to
// utils.RequestLengthOf, whose unit is defined there (presumably an estimate in
// tokens — confirm). It returns 0 for a nil request or when the fallback errors.
func requestInputTokens(request *fwkscheduling.LLMRequest) int {
	switch {
	case request == nil:
		return 0
	case request.TokenizedPrompt != nil && len(request.TokenizedPrompt.TokenIDs) > 0:
		return len(request.TokenizedPrompt.TokenIDs)
	}
	if length, err := utils.RequestLengthOf(request); err == nil {
		return length
	}
	return 0
}
// selectedTarget returns the first target endpoint of the primary scheduling
// profile. It returns nil when there is no such endpoint or it has no metadata.
func selectedTarget(result *fwkscheduling.SchedulingResult) fwkscheduling.Endpoint {
	if targets := primaryProfileTargets(result); len(targets) > 0 {
		if first := targets[0]; first != nil && first.GetMetadata() != nil {
			return first
		}
	}
	return nil
}
// primaryProfileTargets returns the target endpoints of the result's primary
// profile. It returns nil when the result, the primary profile name, the
// profile results, or the profile's targets are missing or empty.
func primaryProfileTargets(result *fwkscheduling.SchedulingResult) []fwkscheduling.Endpoint {
	if result == nil || result.PrimaryProfileName == "" || result.ProfileResults == nil {
		return nil
	}
	profile := result.ProfileResults[result.PrimaryProfileName]
	if profile != nil && len(profile.TargetEndpoints) > 0 {
		return profile.TargetEndpoints
	}
	return nil
}
// hasCompletePDRoute reports whether route is non-nil and names a leader, a
// prefill and a decode endpoint.
func hasCompletePDRoute(route *prerequest.Route) bool {
	if route == nil {
		return false
	}
	return route.Leader != nil && route.Prefill != nil && route.Decode != nil
}
// readLatencyAttribute reads the *latencyattr.LatencyPredictionInfo stored
// under key on endpoint and copies its TTFT() and TPOT() values. It returns the
// zero prediction when the endpoint is nil, the attribute is absent, or the
// attribute has a different type or is a nil pointer.
func readLatencyAttribute(endpoint fwkscheduling.Endpoint, key string) internalinflight.LatencyPrediction {
	var result internalinflight.LatencyPrediction
	if endpoint == nil {
		return result
	}
	raw, found := endpoint.Get(key)
	if !found {
		return result
	}
	if info, isInfo := raw.(*latencyattr.LatencyPredictionInfo); isInfo && info != nil {
		result.TTFTMillis = info.TTFT()
		result.TPOTMillis = info.TPOT()
	}
	return result
}
// latencyPrediction reads the endpoint's prediction stored under
// latencyattr.LatencyPredictionInfoKey.
func latencyPrediction(endpoint fwkscheduling.Endpoint) internalinflight.LatencyPrediction {
	return readLatencyAttribute(endpoint, latencyattr.LatencyPredictionInfoKey)
}

// pdPrediction combines a prefill/decode pair's predictions. TTFT is taken
// from the prefill endpoint and TPOT from the decode endpoint, both read under
// key.
func pdPrediction(prefill, decode fwkscheduling.Endpoint, key string) internalinflight.LatencyPrediction {
	ttft := readLatencyAttribute(prefill, key).TTFTMillis
	tpot := readLatencyAttribute(decode, key).TPOTMillis
	return internalinflight.LatencyPrediction{TTFTMillis: ttft, TPOTMillis: tpot}
}
// selectedRoutingSnapshot returns the prefix-cache score, the latency
// prediction, and the derived prediction mode for the selected endpoint.
// When the selected endpoint carries a prefill/decode pair, the score comes
// from the prefill endpoint and the prediction from pdPrediction. Otherwise
// both come from the selected endpoint itself.
func selectedRoutingSnapshot(selected fwkscheduling.Endpoint) (
	float64,
	internalinflight.LatencyPrediction,
	string,
) {
	scoreSource := selected
	var active internalinflight.LatencyPrediction
	if prefillEP, decodeEP, isPD := selectedPDRouteEndpoints(selected); isPD {
		active = pdPrediction(prefillEP, decodeEP, latencyattr.LatencyPredictionInfoKey)
		scoreSource = prefillEP
	} else {
		active = latencyPrediction(selected)
	}
	return score.PrefixCacheScoreFromEndpoint(scoreSource), active, derivePredictionMode(active)
}
// derivePredictionMode returns internalinflight.PredictionModeActive when any
// field of the prediction is non-zero, and "" for the zero prediction.
func derivePredictionMode(active internalinflight.LatencyPrediction) string {
	if active == (internalinflight.LatencyPrediction{}) {
		return ""
	}
	return internalinflight.PredictionModeActive
}
// outputLength returns the response's reported completion-token count. A nil
// response or a negative count yields 0.
func outputLength(response *requestcontrol.Response) int {
	if response != nil {
		if tokens := response.Usage.CompletionTokens; tokens >= 0 {
			return tokens
		}
	}
	return 0
}
// trackedRouteEndpoints returns [prefill, decode] for a complete
// prefill/decode route, and nil otherwise.
func trackedRouteEndpoints(route *prerequest.Route) []*fwkdl.EndpointMetadata {
	var endpoints []*fwkdl.EndpointMetadata
	if hasCompletePDRoute(route) {
		endpoints = append(endpoints, route.Prefill, route.Decode)
	}
	return endpoints
}
// selectedRouteSnapshot describes the route chosen for a request.
// For a prefill/decode selection, the leader is recorded by identity only and
// the prefill and decode endpoints get full snapshots. Otherwise the selected
// endpoint is snapshotted as the leader. A nil selection yields nil.
func selectedRouteSnapshot(
	store *internalinflight.Store,
	selected fwkscheduling.Endpoint,
) *internalinflight.SelectedRoute {
	if selected == nil {
		return nil
	}
	prefillEP, decodeEP, isPD := selectedPDRouteEndpoints(selected)
	if !isPD {
		return &internalinflight.SelectedRoute{Leader: endpointSnapshot(store, selected)}
	}
	route := &internalinflight.SelectedRoute{Leader: endpointIdentitySnapshot(selected)}
	route.Prefill = endpointSnapshot(store, prefillEP)
	route.Decode = endpointSnapshot(store, decodeEP)
	return route
}
// canTrackResponse reports whether ResponseBody has everything it needs:
// a tracker with a store, a request with a non-empty ID, and a response.
func canTrackResponse(
	tracker *RequestLifecycleTracker,
	request *fwkscheduling.LLMRequest,
	response *requestcontrol.Response,
) bool {
	hasStore := tracker != nil && tracker.store != nil
	hasRequestID := request != nil && request.RequestId != ""
	return hasStore && hasRequestID && response != nil
}
// selectedPDRouteEndpoints returns the prefill and decode endpoints recorded
// in the selected endpoint's pdgroup info. The boolean is true only when both
// endpoints are present; when it is false, the returned endpoints must not be
// used.
func selectedPDRouteEndpoints(selected fwkscheduling.Endpoint) (fwkscheduling.Endpoint, fwkscheduling.Endpoint, bool) {
	info, found := pdgroup.GetDirect(selected)
	if !found || info == nil || info.SelectedPrefillPod == nil || info.SelectedDecodePod == nil {
		return nil, nil, false
	}
	prefill, decode := info.SelectedPrefillPod.Endpoint, info.SelectedDecodePod.Endpoint
	return prefill, decode, prefill != nil && decode != nil
}
// endpointIdentitySnapshot returns a snapshot carrying only the endpoint's ID
// (no metrics, load or NPU data), or nil for a nil endpoint.
func endpointIdentitySnapshot(endpoint fwkscheduling.Endpoint) *internalinflight.RouteEndpointSnapshot {
	if endpoint != nil {
		return &internalinflight.RouteEndpointSnapshot{ID: endpointID(endpoint.GetMetadata())}
	}
	return nil
}
// endpointSnapshot captures an endpoint's ID, a copy of its metrics, the
// store's current inflight load for it, and its NPU state. It returns nil for
// a nil endpoint.
func endpointSnapshot(
	store *internalinflight.Store,
	endpoint fwkscheduling.Endpoint,
) *internalinflight.RouteEndpointSnapshot {
	if endpoint == nil {
		return nil
	}
	md := endpoint.GetMetadata()
	inflight := store.Snapshot(md)

	snap := &internalinflight.RouteEndpointSnapshot{ID: endpointID(md)}
	snap.Metrics = endpointMetricsSnapshot(endpoint.GetMetrics())
	snap.Inflight = &internalinflight.Load{Tokens: inflight.Tokens, Requests: inflight.Requests}
	snap.NPU = endpointNPUSnapshot(endpoint)
	return snap
}
// endpointID derives a stable identifier for an endpoint.
// It prefers "pod://<namespace>/<pod>" when a pod name is known, and falls
// back to "addr://<address>:<port>". It returns "" when the metadata is nil or
// lacks both forms. Note that IPv6 addresses are not bracketed in the addr form.
func endpointID(metadata *fwkdl.EndpointMetadata) string {
	switch {
	case metadata == nil:
		return ""
	case metadata.PodName != "":
		return fmt.Sprintf("pod://%s/%s", metadata.NamespacedName.Namespace, metadata.PodName)
	case metadata.Address != "" && metadata.Port != "":
		return fmt.Sprintf("addr://%s:%s", metadata.Address, metadata.Port)
	default:
		return ""
	}
}
// endpointMetricsSnapshot copies the endpoint metrics into a snapshot. The
// model maps are cloned so the snapshot does not alias the live metrics maps.
// A nil input yields nil.
func endpointMetricsSnapshot(metrics *fwkdl.Metrics) *internalinflight.EndpointMetricsSnapshot {
	if metrics == nil {
		return nil
	}
	snapshot := internalinflight.EndpointMetricsSnapshot{
		// Per-model request state.
		ActiveModels:    maps.Clone(metrics.ActiveModels),
		WaitingModels:   maps.Clone(metrics.WaitingModels),
		MaxActiveModels: metrics.MaxActiveModels,
		// Queue depth.
		RunningRequestsSize: metrics.RunningRequestsSize,
		WaitingQueueSize:    metrics.WaitingQueueSize,
		// KV cache.
		KVCacheUsagePercent:     metrics.KVCacheUsagePercent,
		KVCacheMaxTokenCapacity: metrics.KvCacheMaxTokenCapacity,
		CacheBlockSize:          metrics.CacheBlockSize,
		CacheNumBlocks:          metrics.CacheNumBlocks,
		UpdateTime:              metrics.UpdateTime,
	}
	return &snapshot
}
// endpointNPUSnapshot copies the *npuattr.State stored under
// npuattr.DefaultStateAttributeKey: the aggregate utilization figures plus one
// entry per device. It returns nil when the endpoint is nil, or when the
// attribute is absent, of another type, or a nil pointer.
func endpointNPUSnapshot(endpoint fwkscheduling.Endpoint) *internalinflight.NPUStateSnapshot {
	if endpoint == nil {
		return nil
	}
	attr, present := endpoint.Get(npuattr.DefaultStateAttributeKey)
	if !present {
		return nil
	}
	state, isState := attr.(*npuattr.State)
	if !isState || state == nil {
		return nil
	}

	snapshot := &internalinflight.NPUStateSnapshot{
		ObservationTime:            state.ObservationTime,
		DataStatus:                 string(state.DataStatus),
		DeviceCount:                state.DeviceCount,
		AICoreUtilizationAvg:       state.AICoreUtilizationAvg,
		AICoreUtilizationMax:       state.AICoreUtilizationMax,
		HBMUtilizationAvg:          state.HBMUtilizationAvg,
		HBMUtilizationMax:          state.HBMUtilizationMax,
		HBMBandwidthUtilizationAvg: state.HBMBandwidthUtilizationAvg,
		HBMBandwidthUtilizationMax: state.HBMBandwidthUtilizationMax,
		Devices:                    make([]internalinflight.NPUDeviceSnapshot, 0, len(state.Devices)),
	}
	for _, dev := range state.Devices {
		snapshot.Devices = append(snapshot.Devices, internalinflight.NPUDeviceSnapshot{
			Target:                  dev.Target,
			DeviceID:                dev.DeviceID,
			AICoreUtilization:       dev.AICoreUtilization,
			HBMUtilization:          dev.HBMUtilization,
			HBMBandwidthUtilization: dev.HBMBandwidthUtilization,
		})
	}
	return snapshot
}
// registerShutdownClose closes store once the handle's context is done.
//
// Each call starts one goroutine that blocks on ctx.Done(). If the handle or
// its context is nil, nothing is registered and the store is not closed here.
// The goroutine has no caller to return a Close error to, so the error is
// logged rather than silently discarded. If several trackers share one store,
// each registers its own Close; presumably Store.Close tolerates repeated
// calls — confirm in internalinflight.
func registerShutdownClose(handle fwkplugin.Handle, store *internalinflight.Store) {
	if handle == nil || store == nil {
		return
	}
	ctx := handle.Context()
	if ctx == nil {
		return
	}
	go func() {
		<-ctx.Done()
		if err := store.Close(); err != nil {
			log.FromContext(ctx).WithName("RequestLifecycleTracker.Store").
				Error(err, "Failed to close request lifecycle store on shutdown")
		}
	}()
}
// storeLogger returns the logger handed to the inflight store. It is derived
// from the handle's context and tagged with the plugin and store names. When
// there is no handle or no context, it falls back to a discarding logger.
func storeLogger(handle fwkplugin.Handle, pluginName string, storeName string) logr.Logger {
	if handle == nil {
		return logr.Discard()
	}
	if ctx := handle.Context(); ctx != nil {
		return log.FromContext(ctx).
			WithName("RequestLifecycleTracker.Store").
			WithValues("plugin", pluginName, "store", storeName)
	}
	return logr.Discard()
}