* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package warmupjob
import (
"context"
"fmt"
"log/slog"
"net/http"
"path/filepath"
"slices"
"strings"
"time"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
warmupv1alpha1 "github.com/openfuyao/weight-dispatcher/api/v1alpha1"
"github.com/openfuyao/weight-dispatcher/pkg/dataplane/rdma"
"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
"github.com/openfuyao/weight-dispatcher/pkg/node"
"github.com/openfuyao/weight-dispatcher/pkg/planning/transferplanner"
sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)
// AgentClient is the set of node-agent operations the warmup reconciler
// depends on. Every call is addressed to one specific node.
type AgentClient interface {
	// SubmitWarmup submits a warmup execution plan; the returned handle's
	// TaskID is recorded in the node's WarmupNodeState.
	SubmitWarmup(ctx context.Context, node corev1.Node, req sharedtypes.SubmitWarmupRequest) (sharedtypes.TaskHandle, error)
	// GetWarmupTaskStatus polls the status of a previously submitted task.
	GetWarmupTaskStatus(ctx context.Context, node corev1.Node, req sharedtypes.GetWarmupTaskStatusRequest) (sharedtypes.TaskStatus, error)
	// BuildManifest asks the node to produce a logical manifest for a source.
	BuildManifest(ctx context.Context, node corev1.Node, req sharedtypes.BuildManifestRequest) (sharedtypes.BuildManifestResponse, error)
	// OpenCollective opens a collective session for a task before the task is
	// submitted.
	OpenCollective(ctx context.Context, node corev1.Node, req sharedtypes.OpenCollectiveRequest) (sharedtypes.OpenCollectiveResponse, error)
	// StepCollective sends a single collective step request to the node.
	StepCollective(ctx context.Context, node corev1.Node, req sharedtypes.CollectiveStepRequest) (sharedtypes.CollectiveStepResponse, error)
	// CompleteCollective closes a collective session, reporting whether the
	// associated task succeeded.
	CompleteCollective(ctx context.Context, node corev1.Node, req sharedtypes.CompleteCollectiveRequest) error
}
// Reconciler reconciles ModelWarmupJob resources. It resolves target nodes,
// builds a cache plan through Dispatcher, submits and polls per-node warmup
// tasks through Agent, and mirrors their progress into the job status.
type Reconciler struct {
	client.Client
	Scheme *runtime.Scheme
	// Resolver maps spec.target and source node names to Node objects.
	Resolver *node.Resolver
	// Dispatcher builds cache plans and per-node warmup executions.
	Dispatcher transferplanner.Dispatcher
	// Agent issues requests to the node agents.
	Agent AgentClient
	// RequeueAfter is the poll interval while a job is not terminal; zero
	// means 500ms.
	RequeueAfter time.Duration
	// Logger is used for reconcile logging; nil falls back to slog.Default().
	Logger *slog.Logger
}
// collectiveCompletionTarget describes one CompleteCollective call issued
// after the job has reached a terminal phase.
type collectiveCompletionTarget struct {
	nodeName string
	taskID   string
	// sessionID is the collective session opened for the task.
	sessionID string
	// success is true when the node's task phase is Succeeded.
	success bool
	// message is the node state's last message, forwarded to the agent.
	message string
	// stagingPath is derived from the intent's target path by
	// deriveCollectiveStagingPath.
	stagingPath string
}
// pendingExecution is a warmup execution built for a node that has no task ID
// yet, together with the node and intent it was built from.
type pendingExecution struct {
	node      corev1.Node
	intent    transferplanner.WarmupNodeIntent
	execution sharedtypes.WarmupExecutionPlan
	// openErr is set when OpenCollective failed; submission is skipped for
	// this node in the current reconcile.
	openErr error
}
// reconcilePreparation bundles everything computed before per-node states are
// reconciled: the job (whose status may already have been adjusted), the
// resolved target nodes and their names, the logical manifest, and the
// collective-enriched cache plan.
type reconcilePreparation struct {
	job       warmupv1alpha1.ModelWarmupJob
	nodes     []corev1.Node
	targets   []string
	manifest  sharedtypes.LogicalManifest
	cachePlan transferplanner.CacheBuildPlan
}
// reconcileLoadResult is the outcome of loadWarmupJob. When done is true,
// Reconcile stops and returns result; otherwise job holds the loaded object.
type reconcileLoadResult struct {
	job    warmupv1alpha1.ModelWarmupJob
	result ctrl.Result
	done   bool
}
// reconcileValidationResult is the outcome of validateWarmupJob; done means
// the job failed validation and Reconcile should stop.
type reconcileValidationResult struct {
	result ctrl.Result
	done   bool
}
// targetResolutionResult is the outcome of resolveTargetNodes. When done is
// false, nodes holds the resolved nodes and targets their names.
type targetResolutionResult struct {
	nodes   []corev1.Node
	targets []string
	result  ctrl.Result
	done    bool
}
// reconcilePlanResult is the outcome of prepareReconcile/prepareReconcilePlan.
// When done is true, Reconcile returns result; otherwise preparation is fully
// populated.
type reconcilePlanResult struct {
	preparation reconcilePreparation
	result      ctrl.Result
	done        bool
}
// Values of SourceSpec.SourceType recognised by the reconciler.
const (
	// sourceTypeNode is a source served by a cluster node's agent.
	sourceTypeNode = "node"
	// sourceTypeExternal is a source outside the cluster (e.g. a Hugging Face
	// endpoint).
	sourceTypeExternal = "external"
)
// Reconcile drives one ModelWarmupJob towards completion.
//
// It loads and validates the job, resolves target nodes, and builds the cache
// plan. It then builds executions for nodes that have no task yet and opens
// collective sessions for those that need one. After that it submits new tasks,
// polls existing ones, and writes the aggregated status.
//
// Non-terminal jobs are requeued after r.RequeueAfter (500ms when unset). Once
// the job is terminal, the collective sessions of its tasks are completed on
// their nodes.
func (r *Reconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	logger := reconcileLogger(r.Logger)
	plan, err := r.prepareReconcile(ctx, logger, req)
	if err != nil || plan.done {
		return plan.result, err
	}
	prepared := plan.preparation
	// Keep only node states whose task IDs belong to the current plan.
	existing := filterExistingNodeStates(prepared.job.Status.NodeStates, prepared.cachePlan.PlanID)
	pendingExecutions, err := r.buildPendingExecutions(ctx, prepared.cachePlan, prepared.nodes, existing, prepared.job.Spec, prepared.manifest)
	if err != nil {
		return ctrl.Result{}, r.logAndWrap(logger, "build pending warmup executions failed", "build pending executions", err, "name", prepared.job.Name)
	}
	// Open failures are recorded inside pendingExecutions and surface as
	// Pending node states below; they do not abort the reconcile.
	r.openPendingCollectives(ctx, logger, prepared.cachePlan.NodeIntents, pendingExecutions, prepared.job.Spec.Policy.TimeoutSeconds)
	nextStates, err := r.reconcileNodeStates(ctx, logger, prepared.nodes, prepared.cachePlan.NodeIntents, existing, pendingExecutions)
	if err != nil {
		return ctrl.Result{}, r.logAndWrap(logger, "reconcile node warmup states failed", "reconcile node states", err, "name", prepared.job.Name)
	}
	if err := r.updateJobStatus(ctx, &prepared.job, prepared.cachePlan.PlanID, prepared.targets, nextStates); err != nil {
		return ctrl.Result{}, r.logAndWrap(logger, "update warmup job status failed", "update job status", err, "name", prepared.job.Name)
	}
	// updateJobStatus recomputed the phase; keep polling until it is terminal.
	if !isTerminal(prepared.job.Status.Phase) {
		delay := r.RequeueAfter
		if delay == 0 {
			delay = 500 * time.Millisecond
		}
		return ctrl.Result{RequeueAfter: delay}, nil
	}
	r.completeCollectiveSessions(ctx, logger, prepared.nodes, prepared.cachePlan, nextStates)
	return ctrl.Result{}, nil
}
// reconcileLogger returns logger when it is set and slog.Default() otherwise.
func reconcileLogger(logger *slog.Logger) *slog.Logger {
	if logger != nil {
		return logger
	}
	return slog.Default()
}
// prepareReconcile loads the job named by req and, when loading neither
// failed nor decided the outcome, builds the reconcile plan for it. A result
// with done set (or a non-nil error) tells Reconcile to return immediately.
func (r *Reconciler) prepareReconcile(
	ctx context.Context,
	logger *slog.Logger,
	req ctrl.Request,
) (reconcilePlanResult, error) {
	loaded, loadErr := r.loadWarmupJob(ctx, logger, req)
	if loadErr == nil && !loaded.done {
		return r.prepareReconcilePlan(ctx, logger, loaded.job)
	}
	return reconcilePlanResult{result: loaded.result, done: loaded.done}, loadErr
}
// loadWarmupJob fetches the ModelWarmupJob named by req.
//
// It reports done in three cases:
//   - the object no longer exists (nil error);
//   - the Get fails (wrapped, logged error);
//   - the job is already terminal for its current generation.
//
// Otherwise it returns the loaded job for further processing.
func (r *Reconciler) loadWarmupJob(
	ctx context.Context,
	logger *slog.Logger,
	req ctrl.Request,
) (reconcileLoadResult, error) {
	job := warmupv1alpha1.ModelWarmupJob{}
	getErr := r.Get(ctx, req.NamespacedName, &job)
	switch {
	case getErr == nil:
		// Loaded; continue with the terminal-state check below.
	case apierrors.IsNotFound(getErr):
		// The object was deleted; there is nothing left to reconcile.
		return reconcileLoadResult{done: true}, nil
	default:
		return reconcileLoadResult{done: true}, r.logAndWrap(logger, "get warmup job failed", "get warmup job", getErr, "namespace", req.Namespace, "name", req.Name)
	}
	logger.Debug("开始处理 WarmupJob", "namespace", job.Namespace, "name", job.Name, "generation", job.Generation)
	if isTerminal(job.Status.Phase) && job.Status.ObservedGeneration == job.Generation {
		// Already finished for this spec generation.
		return reconcileLoadResult{done: true}, nil
	}
	return reconcileLoadResult{job: job}, nil
}
// prepareReconcilePlan validates the job, resolves its target nodes and builds
// the collective-enriched cache plan. If plan building fails, the job is
// marked PlanBuildFailed and the build error is returned (or the wrapped
// status-update error if marking failed). On success, stale node states from a
// previous plan are dropped and the full preparation is returned.
func (r *Reconciler) prepareReconcilePlan(
	ctx context.Context,
	logger *slog.Logger,
	job warmupv1alpha1.ModelWarmupJob,
) (reconcilePlanResult, error) {
	checked, err := r.validateWarmupJob(ctx, logger, &job)
	if err != nil || checked.done {
		return reconcilePlanResult{result: checked.result, done: checked.done}, err
	}
	resolved, err := r.resolveTargetNodes(ctx, logger, &job)
	if err != nil || resolved.done {
		return reconcilePlanResult{result: resolved.result, done: resolved.done}, err
	}
	manifest, cachePlan, planErr := r.buildReconcileCachePlan(ctx, logger, job, resolved.nodes, resolved.targets)
	if planErr != nil {
		statusErr := r.markPlanBuildFailed(ctx, &job, planErr)
		if statusErr == nil {
			return reconcilePlanResult{done: true}, planErr
		}
		return reconcilePlanResult{done: true}, r.logAndWrap(logger, "update warmup job status for plan build failure failed", "update warmup job status for plan build failure", statusErr, "namespace", job.Namespace, "name", job.Name)
	}
	resetStalePlanState(logger, &job, cachePlan.PlanID)
	prep := reconcilePreparation{
		job:       job,
		nodes:     resolved.nodes,
		targets:   resolved.targets,
		manifest:  manifest,
		cachePlan: cachePlan,
	}
	return reconcilePlanResult{preparation: prep}, nil
}
// validateWarmupJob runs validateJob on job.Spec. A valid spec yields a zero
// result and nil error.
//
// For an invalid spec it marks the job Failed with LastErrorCode
// "ValidationFailed" and records the generation the verdict applies to. It then
// persists the status and returns done=true with the validation error. If
// persisting fails, it returns the wrapped status-update error instead.
func (r *Reconciler) validateWarmupJob(
	ctx context.Context,
	logger *slog.Logger,
	job *warmupv1alpha1.ModelWarmupJob,
) (reconcileValidationResult, error) {
	err := validateJob(job.Spec)
	if err == nil {
		return reconcileValidationResult{}, nil
	}
	// Record the evaluated generation, as markNoTargetNodes and
	// markPlanBuildFailed do. loadWarmupJob only skips terminal jobs whose
	// ObservedGeneration equals Generation. Without this, every retry caused
	// by the returned error would re-validate and rewrite the status until the
	// spec changes.
	job.Status.ObservedGeneration = job.Generation
	job.Status.Phase = warmupv1alpha1.JobPhaseFailed
	job.Status.LastErrorCode = "ValidationFailed"
	job.Status.LastErrorMessage = err.Error()
	logger.Error("warmup job validation failed", "namespace", job.Namespace, "name", job.Name, "error", err)
	if updateErr := r.Status().Update(ctx, job); updateErr != nil {
		return reconcileValidationResult{done: true}, r.logAndWrap(logger, "update warmup job status after validation failure failed", "update warmup job status after validation failure", updateErr, "namespace", job.Namespace, "name", job.Name)
	}
	return reconcileValidationResult{done: true}, err
}
// resolveTargetNodes resolves job.Spec.Target to Node objects.
//
// When resolution fails or matches no nodes, the job is marked Failed via
// markNoTargetNodes and done is returned with a nil error, so the job is not
// retried. A failure to persist that status is returned as a wrapped error.
// On success it returns the nodes and their names.
func (r *Reconciler) resolveTargetNodes(
	ctx context.Context,
	logger *slog.Logger,
	job *warmupv1alpha1.ModelWarmupJob,
) (targetResolutionResult, error) {
	nodes, resolveErr := r.Resolver.Resolve(ctx, job.Spec.Target)
	switch {
	case resolveErr != nil:
		logger.Error("resolve target nodes failed", "namespace", job.Namespace, "name", job.Name, "error", resolveErr)
		reason := fmt.Sprintf("resolve target nodes: %v", resolveErr)
		if updateErr := r.markNoTargetNodes(ctx, job, reason); updateErr != nil {
			return targetResolutionResult{done: true}, r.logAndWrap(logger, "update warmup job status for target resolution failure failed", "update warmup job status for target resolution failure", updateErr, "namespace", job.Namespace, "name", job.Name)
		}
		return targetResolutionResult{done: true}, nil
	case len(nodes) == 0:
		if updateErr := r.markNoTargetNodes(ctx, job, "no target nodes matched spec.target"); updateErr != nil {
			return targetResolutionResult{done: true}, r.logAndWrap(logger, "update warmup job status for empty target set failed", "update warmup job status for empty target set", updateErr, "namespace", job.Namespace, "name", job.Name)
		}
		return targetResolutionResult{done: true}, nil
	default:
		return targetResolutionResult{nodes: nodes, targets: collectTargetNames(nodes)}, nil
	}
}
// buildReconcileCachePlan resolves the logical manifest and normalized sources
// for the job. It then asks the Dispatcher for a cache plan covering targets
// and enriches the plan's collective intents with peer and ring endpoint data
// taken from nodes. Errors from the dispatcher or the enrichment step are
// logged and wrapped.
func (r *Reconciler) buildReconcileCachePlan(
	ctx context.Context,
	logger *slog.Logger,
	job warmupv1alpha1.ModelWarmupJob,
	nodes []corev1.Node,
	targets []string,
) (sharedtypes.LogicalManifest, transferplanner.CacheBuildPlan, error) {
	manifest, sources, err := r.resolveManifestAndSources(ctx, logger, job.Spec)
	if err != nil {
		return sharedtypes.LogicalManifest{}, transferplanner.CacheBuildPlan{}, err
	}
	spec := job.Spec
	request := transferplanner.CacheBuildRequest{
		ArtifactKey:       spec.Artifact.Key,
		ArtifactType:      spec.Artifact.Type,
		WorkloadID:        warmupJobPlanScope(job),
		ExplicitSources:   sources,
		LogicalManifest:   manifest,
		TargetNodes:       targets,
		TargetRootPath:    spec.Target.TargetPath,
		ChunkSizeMB:       spec.Policy.ChunkSizeMB,
		EnableChunkCRC32C: spec.Policy.EnableChunkCRC32C,
		PublishAsSource:   spec.Policy.PublishAsSource,
		TimeoutSeconds:    spec.Policy.TimeoutSeconds,
	}
	plan, err := r.Dispatcher.BuildCachePlan(ctx, request)
	if err != nil {
		return sharedtypes.LogicalManifest{}, transferplanner.CacheBuildPlan{}, r.logAndWrap(logger, "build cache plan failed", "build cache plan", err, "name", job.Name)
	}
	if err = r.enrichCollectivePlan(ctx, nodes, &plan); err != nil {
		return sharedtypes.LogicalManifest{}, transferplanner.CacheBuildPlan{}, r.logAndWrap(logger, "enrich collective plan failed", "enrich collective plan", err, "name", job.Name)
	}
	return manifest, plan, nil
}
// warmupJobPlanScope returns the workload identifier used to scope cache
// plans. It is the job's UID when set, otherwise "namespace/name/generation".
func warmupJobPlanScope(job warmupv1alpha1.ModelWarmupJob) string {
	if uid := string(job.UID); uid != "" {
		return uid
	}
	return fmt.Sprintf("%s/%s/%d", job.Namespace, job.Name, job.Generation)
}
// resolveManifestAndSources normalizes spec.Sources and builds the logical
// manifest from them. It then re-checks every node source against that
// manifest, rewriting source paths to the manifest root where needed. Each
// failure is logged with the artifact key and returned wrapped.
func (r *Reconciler) resolveManifestAndSources(
	ctx context.Context,
	logger *slog.Logger,
	spec warmupv1alpha1.ModelWarmupJobSpec,
) (sharedtypes.LogicalManifest, []warmupv1alpha1.SourceSpec, error) {
	key := spec.Artifact.Key
	sources, err := r.normalizeSources(ctx, spec.Sources)
	if err != nil {
		return sharedtypes.LogicalManifest{}, nil, r.logAndWrap(logger, "normalize sources failed", "normalize sources", err, "artifactKey", key)
	}
	manifest, err := r.buildLogicalManifest(ctx, sources, spec)
	if err != nil {
		return sharedtypes.LogicalManifest{}, nil, r.logAndWrap(logger, "build logical manifest failed", "build logical manifest", err, "artifactKey", key)
	}
	validated, err := r.resolveAndValidateSourceManifests(ctx, sources, spec, manifest)
	if err != nil {
		return sharedtypes.LogicalManifest{}, nil, r.logAndWrap(logger, "resolve and validate source manifests failed", "resolve and validate source manifests", err, "artifactKey", key)
	}
	return manifest, validated, nil
}
// markNoTargetNodes marks the job Failed with LastErrorCode "NoTargetNodes"
// and the given message. It clears resolved nodes, node states and the
// summary, records the current generation, and persists the status.
func (r *Reconciler) markNoTargetNodes(ctx context.Context, job *warmupv1alpha1.ModelWarmupJob, message string) error {
	status := &job.Status
	status.ObservedGeneration = job.Generation
	status.Phase = warmupv1alpha1.JobPhaseFailed
	status.LastErrorCode = "NoTargetNodes"
	status.LastErrorMessage = message
	status.ResolvedNodes = nil
	status.NodeStates = nil
	status.Summary = warmupv1alpha1.WarmupSummary{}
	err := r.Status().Update(ctx, job)
	if err == nil {
		return nil
	}
	return errutil.Wrap("update warmup job status for empty target set", err)
}
// markPlanBuildFailed marks the job Failed with LastErrorCode
// "PlanBuildFailed" and cause's message as the error message. It clears
// resolved nodes, node states and the summary, records the current
// generation, and persists the status. cause must be non-nil.
func (r *Reconciler) markPlanBuildFailed(ctx context.Context, job *warmupv1alpha1.ModelWarmupJob, cause error) error {
	status := &job.Status
	status.ObservedGeneration = job.Generation
	status.Phase = warmupv1alpha1.JobPhaseFailed
	status.LastErrorCode = "PlanBuildFailed"
	status.LastErrorMessage = cause.Error()
	status.ResolvedNodes = nil
	status.NodeStates = nil
	status.Summary = warmupv1alpha1.WarmupSummary{}
	updateErr := r.Status().Update(ctx, job)
	return errutil.Wrap("update warmup job status for plan build failure", updateErr)
}
// collectTargetNames returns the names of nodes in input order. The result is
// always non-nil, even for an empty input.
func collectTargetNames(nodes []corev1.Node) []string {
	names := make([]string, len(nodes))
	for i := range nodes {
		names[i] = nodes[i].Name
	}
	return names
}
// resetStalePlanState drops the job's recorded node states when the status
// refers to a different, non-empty plan ID than planID. This keeps tasks
// from a superseded plan from being reused.
func resetStalePlanState(logger *slog.Logger, job *warmupv1alpha1.ModelWarmupJob, planID string) {
	previous := job.Status.LastPlanID
	if previous != "" && previous != planID {
		logger.Debug("检测到 Warmup 计划变更,重置旧任务视图", "name", job.Name, "oldPlanID", previous, "newPlanID", planID)
		job.Status.NodeStates = nil
	}
}
// filterExistingNodeStates indexes states by node name. It keeps only states
// without a task ID, or whose task ID starts with planID followed by "-".
// When two states share a node name, the later one wins.
func filterExistingNodeStates(states []warmupv1alpha1.WarmupNodeState, planID string) map[string]warmupv1alpha1.WarmupNodeState {
	taskPrefix := planID + "-"
	byNode := make(map[string]warmupv1alpha1.WarmupNodeState, len(states))
	for _, s := range states {
		if s.TaskID == "" || strings.HasPrefix(s.TaskID, taskPrefix) {
			byNode[s.NodeName] = s
		}
	}
	return byNode
}
// logAndWrap logs message at error level with attrs plus an "error"
// attribute, then returns err wrapped with action via errutil.Wrap.
func (*Reconciler) logAndWrap(logger *slog.Logger, message, action string, err error, attrs ...any) error {
	// Build a fresh slice. Appending directly to the variadic parameter could
	// write into the caller's backing array when invoked as
	// logAndWrap(..., attrs...) with a slice that has spare capacity.
	fields := make([]any, 0, len(attrs)+2)
	fields = append(fields, attrs...)
	fields = append(fields, "error", err)
	logger.Error(message, fields...)
	return errutil.Wrap(action, err)
}
// buildPendingExecutions builds, through the Dispatcher, a warmup execution
// for every intent whose node has no task ID recorded in existing. The result
// is keyed by target node. Nodes that already have a task are skipped;
// reconcileNodeStates polls those instead. It fails if an intent's node is
// missing from nodes or an execution cannot be built.
func (r *Reconciler) buildPendingExecutions(
	ctx context.Context,
	cachePlan transferplanner.CacheBuildPlan,
	nodes []corev1.Node,
	existing map[string]warmupv1alpha1.WarmupNodeState,
	spec warmupv1alpha1.ModelWarmupJobSpec,
	manifest sharedtypes.LogicalManifest,
) (map[string]pendingExecution, error) {
	out := make(map[string]pendingExecution, len(cachePlan.NodeIntents))
	for _, intent := range cachePlan.NodeIntents {
		target := intent.TargetNode
		if existing[target].TaskID != "" {
			continue
		}
		nodeObj, ok := findNodeByName(nodes, target)
		if !ok {
			return nil, fmt.Errorf("target node %s not found in resolved node list", target)
		}
		req := transferplanner.BuildWarmupExecutionRequest{
			PlanID:            cachePlan.PlanID,
			ArtifactKey:       spec.Artifact.Key,
			ArtifactType:      spec.Artifact.Type,
			TargetNode:        target,
			TargetPath:        intent.TargetPath,
			ChunkSizeMB:       spec.Policy.ChunkSizeMB,
			EnableChunkCRC32C: spec.Policy.EnableChunkCRC32C,
			PublishAsSource:   spec.Policy.PublishAsSource,
			TimeoutSeconds:    spec.Policy.TimeoutSeconds,
			LogicalManifest:   manifest,
			TargetPlan:        intent.TargetPlan,
			CollectiveSpec:    intent.Collective,
		}
		execution, err := r.Dispatcher.BuildWarmupExecution(ctx, req)
		if err != nil {
			return nil, errutil.Wrap("build warmup execution", err)
		}
		out[target] = pendingExecution{node: nodeObj, intent: intent, execution: execution}
	}
	return out, nil
}
// openPendingCollectives opens a collective session for every pending
// execution whose CollectiveSpec.Mode is not CollectiveModeNone. Nodes are
// visited in intent order.
//
// Before each open it issues a best-effort CompleteCollective with
// Success=false (message "pre-open cleanup sweep"); a failure there is only
// logged. NOTE(review): presumably this clears a session left behind by an
// earlier attempt for the same task — confirm against the agent.
//
// An OpenCollective failure is stored in the entry's openErr, written back into
// pendingExecutions, and handlePendingWarmupState then skips submission for
// that node. A nil map is logged and ignored.
func (r *Reconciler) openPendingCollectives(
	ctx context.Context,
	logger *slog.Logger,
	intents []transferplanner.WarmupNodeIntent,
	pendingExecutions map[string]pendingExecution,
	timeoutSeconds int32,
) {
	if pendingExecutions == nil {
		logger.Warn("skip opening collectives because pending executions map is nil")
		return
	}
	for _, intent := range intents {
		targetNode := intent.TargetNode
		pending, ok := pendingExecutions[targetNode]
		if !ok {
			// The node already has a task; nothing to open.
			continue
		}
		if pending.execution.CollectiveSpec.Mode == sharedtypes.CollectiveModeNone {
			continue
		}
		if cleanupErr := r.Agent.CompleteCollective(ctx, pending.node, sharedtypes.CompleteCollectiveRequest{
			TaskID:    pending.execution.TaskID,
			SessionID: pending.execution.CollectiveSpec.SessionID,
			Success:   false,
			Message:   "pre-open cleanup sweep",
		}); cleanupErr != nil {
			logger.Warn("pre-open collective cleanup sweep failed", "targetNode", targetNode, "taskID", pending.execution.TaskID, "error", cleanupErr)
		}
		openResp, openErr := r.Agent.OpenCollective(ctx, pending.node, sharedtypes.OpenCollectiveRequest{
			TaskID:         pending.execution.TaskID,
			SessionID:      pending.execution.CollectiveSpec.SessionID,
			ArtifactKey:    pending.execution.ArtifactKey,
			CollectiveSpec: pending.execution.CollectiveSpec,
			EnableChunkCRC: pending.execution.EnableChunkCRC32C,
			TimeoutSeconds: timeoutSeconds,
		})
		if openErr != nil {
			// pending is a copy; write it back so the caller sees openErr.
			pending.openErr = openErr
			pendingExecutions[targetNode] = pending
			continue
		}
		logger.Debug("collective 会话已打开", "targetNode", targetNode, "taskID", pending.execution.TaskID, "transportPath", openResp.TransportPath)
	}
}
// reconcileNodeStates computes the next state for every intent, in intent
// order. Nodes without a task ID go through handlePendingWarmupState, which
// submits the warmup. Nodes with a task ID are polled via
// reconcileExistingNodeState. It fails if an intent's node is not in nodes.
func (r *Reconciler) reconcileNodeStates(
	ctx context.Context,
	logger *slog.Logger,
	nodes []corev1.Node,
	intents []transferplanner.WarmupNodeIntent,
	existing map[string]warmupv1alpha1.WarmupNodeState,
	pendingExecutions map[string]pendingExecution,
) ([]warmupv1alpha1.WarmupNodeState, error) {
	states := make([]warmupv1alpha1.WarmupNodeState, 0, len(nodes))
	for _, intent := range intents {
		target := intent.TargetNode
		nodeObj, ok := findNodeByName(nodes, target)
		if !ok {
			return nil, fmt.Errorf("target node %s not found in resolved node list", target)
		}
		current := existing[target]
		if current.NodeName == "" {
			current.NodeName = target
		}
		var next warmupv1alpha1.WarmupNodeState
		if current.TaskID == "" {
			next = r.handlePendingWarmupState(ctx, logger, nodeObj, intent, pendingExecutions[target], current)
		} else {
			next = r.reconcileExistingNodeState(ctx, logger, nodeObj, intent, current)
		}
		states = append(states, next)
	}
	return states, nil
}
// reconcileExistingNodeState polls the agent for a node that already has a
// task ID.
//   - Terminal states are returned unchanged.
//   - A successful poll is applied with applyTaskStatus.
//   - If the agent answers 404, the task is treated as lost: task ID, timing,
//     transfer counters and cache path are cleared and the phase is set to
//     Pending, so a later reconcile resubmits it.
//
// Any poll error is stored in Message, and an empty phase defaults to Running.
func (r *Reconciler) reconcileExistingNodeState(
	ctx context.Context,
	logger *slog.Logger,
	nodeObj corev1.Node,
	intent transferplanner.WarmupNodeIntent,
	state warmupv1alpha1.WarmupNodeState,
) warmupv1alpha1.WarmupNodeState {
	if isTerminal(state.Phase) {
		return state
	}
	request := sharedtypes.GetWarmupTaskStatusRequest{TaskID: state.TaskID}
	status, pollErr := r.Agent.GetWarmupTaskStatus(ctx, nodeObj, request)
	if pollErr == nil {
		applyTaskStatus(&state, status)
		return state
	}
	if IsAgentHTTPStatus(pollErr, http.StatusNotFound) {
		logger.Warn("warmup task disappeared on node-agent, clearing task state for resubmit", "targetNode", intent.TargetNode, "taskID", state.TaskID)
		state.TaskID = ""
		state.Phase = warmupv1alpha1.JobPhasePending
		state.StartedAt, state.FinishedAt = nil, nil
		state.BytesTransferred = 0
		state.ThroughputMBps = 0
		state.CachePath = ""
	}
	state.Message = pollErr.Error()
	if state.Phase == "" {
		state.Phase = warmupv1alpha1.JobPhaseRunning
	}
	return state
}
// updateJobStatus writes the reconciled node states, their summary and the
// aggregated phase into the job status, records planID and the current
// generation, and persists the status. Note that targetNames is sorted in
// place and stored as ResolvedNodes.
func (r *Reconciler) updateJobStatus(
	ctx context.Context,
	job *warmupv1alpha1.ModelWarmupJob,
	planID string,
	targetNames []string,
	nextStates []warmupv1alpha1.WarmupNodeState,
) error {
	slices.Sort(targetNames)
	status := &job.Status
	status.ObservedGeneration = job.Generation
	status.ResolvedNodes = targetNames
	status.NodeStates = nextStates
	status.LastPlanID = planID
	status.Summary = summarize(nextStates)
	status.Phase = aggregatePhase(status.Summary)
	updateErr := r.Status().Update(ctx, job)
	return errutil.Wrap("update warmup job status", updateErr)
}
// completeCollectiveSessions sends CompleteCollective for every target
// returned by collectiveCompletionTargets. Targets whose node is no longer
// resolved are skipped. All failures are logged as warnings and never
// propagated.
func (r *Reconciler) completeCollectiveSessions(
	ctx context.Context,
	logger *slog.Logger,
	nodes []corev1.Node,
	cachePlan transferplanner.CacheBuildPlan,
	nextStates []warmupv1alpha1.WarmupNodeState,
) {
	targets := collectiveCompletionTargets(cachePlan, nextStates)
	for i := range targets {
		t := targets[i]
		nodeObj, ok := findNodeByName(nodes, t.nodeName)
		if !ok {
			logger.Warn("skip collective cleanup because target node is no longer resolved", "targetNode", t.nodeName, "taskID", t.taskID)
			continue
		}
		req := sharedtypes.CompleteCollectiveRequest{
			TaskID:      t.taskID,
			SessionID:   t.sessionID,
			Success:     t.success,
			Message:     t.message,
			StagingPath: t.stagingPath,
		}
		if err := r.Agent.CompleteCollective(ctx, nodeObj, req); err != nil {
			logger.Warn("complete collective session failed", "targetNode", t.nodeName, "taskID", t.taskID, "error", err)
		}
	}
}
// buildLogicalManifest builds the logical manifest from the source chosen by
// selectManifestSource.
//   - Node sources are manifested by that node's agent.
//   - Other sources are accepted only when rdma recognises them as Hugging
//     Face endpoints; they are resolved via rdma.ResolveHuggingFaceManifest.
//
// The chunk size is spec.Policy.ChunkSizeMB MiB, defaulting to 64 MiB when the
// value is not positive.
func (r *Reconciler) buildLogicalManifest(ctx context.Context, sources []warmupv1alpha1.SourceSpec, spec warmupv1alpha1.ModelWarmupJobSpec) (sharedtypes.LogicalManifest, error) {
	source, sourceNode, err := r.selectManifestSource(ctx, sources)
	if err != nil {
		return sharedtypes.LogicalManifest{}, errutil.Wrap("select manifest source", err)
	}
	chunkSizeBytes := int64(spec.Policy.ChunkSizeMB) * 1024 * 1024
	if chunkSizeBytes <= 0 {
		chunkSizeBytes = 64 * 1024 * 1024
	}
	if source.SourceType == sourceTypeNode {
		endpoint := sharedtypes.SourceEndpoint{
			SourceID:   strings.Join([]string{source.SourceType, source.NodeName, source.Endpoint, source.Path}, "|"),
			SourceType: source.SourceType,
			NodeName:   source.NodeName,
			Endpoint:   source.Endpoint,
			Path:       source.Path,
		}
		resp, buildErr := r.Agent.BuildManifest(ctx, sourceNode, sharedtypes.BuildManifestRequest{
			ArtifactKey:    spec.Artifact.Key,
			Source:         endpoint,
			ChunkSizeBytes: chunkSizeBytes,
		})
		if buildErr != nil {
			return sharedtypes.LogicalManifest{}, errutil.Wrap("build manifest through agent", buildErr)
		}
		return resp.Manifest, nil
	}
	if !rdma.IsHuggingFaceEndpoint(source.SourceType, source.Endpoint) {
		return sharedtypes.LogicalManifest{}, fmt.Errorf("phase one does not support sourceType=%s endpoint=%s", source.SourceType, source.Endpoint)
	}
	hfManifest, hfErr := rdma.ResolveHuggingFaceManifest(ctx, nil, rdma.HuggingFaceManifestRequest{
		Endpoint:       source.Endpoint,
		ModelID:        source.Path,
		Revision:       rdma.ExtractHFRevision(source.Endpoint),
		ChunkSizeBytes: chunkSizeBytes,
	})
	if hfErr != nil {
		return sharedtypes.LogicalManifest{}, errutil.Wrap("resolve huggingface manifest", hfErr)
	}
	return hfManifest, nil
}
// alignSourcesToManifest returns a copy of sources in which each path that
// refers to the manifest root is rewritten to the cleaned root.
//   - Single-file manifest: a source is aligned when its base name matches
//     the manifest file's base name.
//   - Otherwise: a source is aligned when its path lies within the root.
//
// When the manifest has no root path, sources is returned unchanged.
func alignSourcesToManifest(sources []warmupv1alpha1.SourceSpec, manifest sharedtypes.LogicalManifest) []warmupv1alpha1.SourceSpec {
	if manifest.RootPath == "" {
		return sources
	}
	root := cleanAgentPath(manifest.RootPath)
	singleFile := len(manifest.Files) == 1
	out := make([]warmupv1alpha1.SourceSpec, 0, len(sources))
	for _, src := range sources {
		srcPath := cleanAgentPath(src.Path)
		var align bool
		if singleFile {
			align = shouldAlignSingleManifestFile(manifest, srcPath)
		} else {
			align = isPathWithin(root, srcPath)
		}
		if align {
			src.Path = root
		}
		out = append(out, src)
	}
	return out
}
// shouldAlignSingleManifestFile reports whether cleanSourcePath names the
// manifest's first file. This means both paths share a base name that is
// neither empty nor ".". The manifest must contain at least one file.
func shouldAlignSingleManifestFile(manifest sharedtypes.LogicalManifest, cleanSourcePath string) bool {
	name := filepath.Base(manifest.Files[0].RelativePath)
	switch name {
	case "", ".":
		return false
	}
	return filepath.Base(cleanSourcePath) == name
}
// resolveAndValidateSourceManifests aligns sources to the manifest root, then
// checks each one with resolveValidatedSource against the manifest's layout.
// It returns the validated sources in input order, or the first error.
func (r *Reconciler) resolveAndValidateSourceManifests(
	ctx context.Context,
	sources []warmupv1alpha1.SourceSpec,
	spec warmupv1alpha1.ModelWarmupJobSpec,
	manifest sharedtypes.LogicalManifest,
) ([]warmupv1alpha1.SourceSpec, error) {
	candidates := alignSourcesToManifest(sources, manifest)
	chunk := manifestChunkSizeBytes(manifest, spec)
	signature := manifestLayoutSignature(manifest)
	out := make([]warmupv1alpha1.SourceSpec, len(candidates))
	for i, candidate := range candidates {
		validated, err := r.resolveValidatedSource(ctx, candidate, spec, chunk, signature, manifest)
		if err != nil {
			return nil, err
		}
		out[i] = validated
	}
	return out, nil
}
// manifestChunkSizeBytes picks the chunk size used to validate sources. In
// order of preference it uses the manifest's own chunk size, the policy's
// ChunkSizeMB in MiB, and finally 64 MiB.
func manifestChunkSizeBytes(manifest sharedtypes.LogicalManifest, spec warmupv1alpha1.ModelWarmupJobSpec) int64 {
	switch {
	case manifest.ChunkSizeBytes > 0:
		return manifest.ChunkSizeBytes
	case spec.Policy.ChunkSizeMB > 0:
		return int64(spec.Policy.ChunkSizeMB) * 1024 * 1024
	default:
		return 64 * 1024 * 1024
	}
}
// resolveValidatedSource verifies a node source by asking its node to build a
// manifest and comparing that manifest's file layout with the expected
// manifest. On a match, the source path is replaced by the returned cleaned
// root path (if any). Sources that are not node sources with a node name are
// returned as-is. expectedSignature is used only in the mismatch error.
func (r *Reconciler) resolveValidatedSource(
	ctx context.Context,
	source warmupv1alpha1.SourceSpec,
	spec warmupv1alpha1.ModelWarmupJobSpec,
	chunkSizeBytes int64,
	expectedSignature string,
	manifest sharedtypes.LogicalManifest,
) (warmupv1alpha1.SourceSpec, error) {
	isNodeSource := source.SourceType == sourceTypeNode && source.NodeName != ""
	if !isNodeSource {
		return source, nil
	}
	nodeObj, err := r.Resolver.GetNode(ctx, source.NodeName)
	if err != nil {
		return warmupv1alpha1.SourceSpec{}, errutil.Wrap("resolve source node", err)
	}
	request := sharedtypes.BuildManifestRequest{
		ArtifactKey: spec.Artifact.Key,
		Source: sharedtypes.SourceEndpoint{
			SourceType: source.SourceType,
			NodeName:   source.NodeName,
			Endpoint:   source.Endpoint,
			Path:       source.Path,
		},
		ChunkSizeBytes: chunkSizeBytes,
	}
	resp, err := r.Agent.BuildManifest(ctx, *nodeObj, request)
	if err != nil {
		return warmupv1alpha1.SourceSpec{}, fmt.Errorf("validate source %s (%s): %w", source.NodeName, source.Path, err)
	}
	if !manifestLayoutsMatch(manifest, resp.Manifest) {
		actualSignature := manifestLayoutSignature(resp.Manifest)
		return warmupv1alpha1.SourceSpec{}, fmt.Errorf(
			"source %s manifest mismatch for %s: expected signature %s, got %s",
			source.NodeName, source.Path, expectedSignature, actualSignature,
		)
	}
	if root := resp.Manifest.RootPath; root != "" {
		source.Path = cleanAgentPath(root)
	}
	return source, nil
}
// manifestLayoutsMatch reports whether both manifests list the same files in
// the same order, comparing each pair with artifactFileLayoutEqual.
func manifestLayoutsMatch(expected, actual sharedtypes.LogicalManifest) bool {
	return slices.EqualFunc(expected.Files, actual.Files, artifactFileLayoutEqual)
}
// artifactFileLayoutEqual reports whether two files agree on relative path,
// size, kind, chunkability and whether they are required.
func artifactFileLayoutEqual(expected, actual sharedtypes.ArtifactFile) bool {
	return expected.RelativePath == actual.RelativePath &&
		expected.SizeBytes == actual.SizeBytes &&
		expected.Kind == actual.Kind &&
		expected.Chunkable == actual.Chunkable &&
		expected.Required == actual.Required
}
// manifestLayoutSignature renders the manifest's file layout as one
// "path|size|kind|chunkable|required" line per file, each terminated by "\n".
// It is used in mismatch error messages.
func manifestLayoutSignature(manifest sharedtypes.LogicalManifest) string {
	lines := make([]string, 0, len(manifest.Files))
	for i := range manifest.Files {
		f := manifest.Files[i]
		lines = append(lines, fmt.Sprintf("%s|%d|%s|%t|%t\n", f.RelativePath, f.SizeBytes, f.Kind, f.Chunkable, f.Required))
	}
	return strings.Join(lines, "")
}
// selectManifestSource picks the source used to build the logical manifest.
//   - The first sourceType=node entry with a node name wins; its resolved Node
//     is returned with it.
//   - Otherwise the first external source that rdma recognises as a Hugging
//     Face endpoint is returned with a zero Node.
//
// An error is returned when neither exists or the node cannot be resolved.
func (r *Reconciler) selectManifestSource(ctx context.Context, sources []warmupv1alpha1.SourceSpec) (warmupv1alpha1.SourceSpec, corev1.Node, error) {
	for _, source := range sources {
		if source.SourceType != sourceTypeNode || source.NodeName == "" {
			continue
		}
		nodeObj, err := r.Resolver.GetNode(ctx, source.NodeName)
		if err != nil {
			return warmupv1alpha1.SourceSpec{}, corev1.Node{}, errutil.Wrap("resolve manifest source node", err)
		}
		return source, *nodeObj, nil
	}
	for _, source := range sources {
		// Compare against the shared constant rather than a duplicated literal,
		// so the accepted source type cannot drift from the rest of the file.
		if source.SourceType == sourceTypeExternal && rdma.IsHuggingFaceEndpoint(source.SourceType, source.Endpoint) {
			return source, corev1.Node{}, nil
		}
	}
	return warmupv1alpha1.SourceSpec{}, corev1.Node{}, fmt.Errorf("phase one requires sourceType=node or a supported external manifest source")
}
// enrichCollectivePlan fills in peer lists and ring endpoints for the
// all-gather intents of cachePlan, using node internal IPs from nodes. A nil
// plan, or a plan without all-gather intents, is left untouched.
func (*Reconciler) enrichCollectivePlan(_ context.Context, nodes []corev1.Node, cachePlan *transferplanner.CacheBuildPlan) error {
	if cachePlan == nil {
		return nil
	}
	byName := make(map[string]corev1.Node, len(nodes))
	for i := range nodes {
		byName[nodes[i].Name] = nodes[i]
	}
	peers, err := buildCollectivePeerPlans(cachePlan.NodeIntents, byName)
	switch {
	case err != nil:
		return err
	case len(peers) == 0:
		return nil
	default:
		return applyCollectivePeers(cachePlan, peers)
	}
}
// collectOwnedRanges concatenates the byte ranges of all segments, in order,
// into a newly allocated slice. It returns nil when there are no ranges at
// all, including when segments is empty.
func collectOwnedRanges(segments []sharedtypes.SourceSegmentPlan) []sharedtypes.ByteRange {
	total := 0
	for _, segment := range segments {
		total += len(segment.ByteRanges)
	}
	if total == 0 {
		return nil
	}
	// Allocate once at the final size. The result never aliases the segments'
	// backing arrays, so no second defensive copy is needed.
	ranges := make([]sharedtypes.ByteRange, 0, total)
	for _, segment := range segments {
		ranges = append(ranges, segment.ByteRanges...)
	}
	return ranges
}
// buildCollectivePeerPlans builds one CollectivePeerPlan per all-gather intent
// (TransferModePartialPullAllGather), keyed by target node. Each peer carries:
//   - the node's internal IP as its endpoint;
//   - its ring rank;
//   - the staging path derived from the intent's target path;
//   - the byte ranges of its source segments.
//
// Intents without a ring assignment are skipped, matching
// applyCollectivePeers; reading Ring.Rank on them would dereference a nil
// pointer. It fails when a node is missing or has no internal IP.
func buildCollectivePeerPlans(
	intents []transferplanner.WarmupNodeIntent,
	nodeByName map[string]corev1.Node,
) (map[string][]sharedtypes.CollectivePeerPlan, error) {
	peerPlans := make(map[string][]sharedtypes.CollectivePeerPlan, len(intents))
	for _, intent := range intents {
		if intent.TargetPlan.TransferMode != sharedtypes.TransferModePartialPullAllGather {
			continue
		}
		ring := intent.Collective.Ring
		if ring == nil {
			continue
		}
		nodeObj, ok := nodeByName[intent.TargetNode]
		if !ok {
			return nil, fmt.Errorf("collective target node %s not found", intent.TargetNode)
		}
		endpoint, err := node.ExtractNodeInternalIP(&nodeObj)
		if err != nil {
			return nil, errutil.Wrap("extract node internal ip", err)
		}
		peerPlans[intent.TargetNode] = []sharedtypes.CollectivePeerPlan{{
			NodeName:    intent.TargetNode,
			Endpoint:    endpoint,
			Rank:        ring.Rank,
			StagingPath: deriveCollectiveStagingPath(intent.TargetPath),
			OwnedRanges: collectOwnedRanges(intent.TargetPlan.SourceSegments),
		}}
	}
	return peerPlans, nil
}
// applyCollectivePeers gives every all-gather intent that has a ring its own
// copy of the flattened peer list, and fills the ring's self/prev/next
// endpoints from it. Intents are modified in place; the error is always nil.
func applyCollectivePeers(cachePlan *transferplanner.CacheBuildPlan, peerPlans map[string][]sharedtypes.CollectivePeerPlan) error {
	peers := flattenCollectivePeers(peerPlans)
	for i := range cachePlan.NodeIntents {
		ni := &cachePlan.NodeIntents[i]
		isAllGather := ni.TargetPlan.TransferMode == sharedtypes.TransferModePartialPullAllGather
		if !isAllGather || ni.Collective.Ring == nil {
			continue
		}
		// Each intent gets an independent copy so later edits cannot alias.
		ni.Collective.Peers = append([]sharedtypes.CollectivePeerPlan(nil), peers...)
		populateRingEndpoints(&ni.Collective, peers)
	}
	return nil
}
// flattenCollectivePeers collects the first peer of every entry in peerPlans
// into one list, ordered by Rank and then NodeName.
//
// Sorting matters because map iteration order is random; without it every
// reconcile would hand the agents the same peers in a different order. Entries
// with no peers are skipped instead of panicking on peers[0].
func flattenCollectivePeers(peerPlans map[string][]sharedtypes.CollectivePeerPlan) []sharedtypes.CollectivePeerPlan {
	allPeers := make([]sharedtypes.CollectivePeerPlan, 0, len(peerPlans))
	for _, peers := range peerPlans {
		if len(peers) == 0 {
			continue
		}
		allPeers = append(allPeers, peers[0])
	}
	slices.SortFunc(allPeers, func(a, b sharedtypes.CollectivePeerPlan) int {
		switch {
		case a.Rank < b.Rank:
			return -1
		case a.Rank > b.Rank:
			return 1
		default:
			return strings.Compare(a.NodeName, b.NodeName)
		}
	})
	return allPeers
}
// populateRingEndpoints copies peer endpoints into the ring's SelfEndpoint,
// PrevEndpoint and NextEndpoint by matching node names. A nil spec or nil
// ring is ignored.
func populateRingEndpoints(collective *sharedtypes.CollectiveSpec, allPeers []sharedtypes.CollectivePeerPlan) {
	if collective == nil || collective.Ring == nil {
		return
	}
	ring := collective.Ring
	for _, peer := range allPeers {
		// Check each role independently rather than with a switch. In small
		// rings one node fills several roles: with two nodes PrevNode ==
		// NextNode, and with one node all three coincide. A switch would set
		// only the first matching endpoint.
		if peer.NodeName == ring.SelfNode {
			ring.SelfEndpoint = peer.Endpoint
		}
		if peer.NodeName == ring.PrevNode {
			ring.PrevEndpoint = peer.Endpoint
		}
		if peer.NodeName == ring.NextNode {
			ring.NextEndpoint = peer.Endpoint
		}
	}
}
// handlePendingWarmupState submits the pending execution for a node that has
// no task yet.
//   - If opening the collective failed, or submission fails, the node stays
//     Pending with the error as its message. On a submit failure the opened
//     collective session is completed as failed.
//   - On success the node becomes Running with the returned task ID and a UTC
//     start time.
func (r *Reconciler) handlePendingWarmupState(
	ctx context.Context,
	logger *slog.Logger,
	nodeObj corev1.Node,
	intent transferplanner.WarmupNodeIntent,
	pending pendingExecution,
	state warmupv1alpha1.WarmupNodeState,
) warmupv1alpha1.WarmupNodeState {
	execution := pending.execution
	var failure error
	if pending.openErr != nil {
		failure = pending.openErr
	} else {
		handle, submitErr := r.Agent.SubmitWarmup(ctx, nodeObj, sharedtypes.SubmitWarmupRequest{Plan: execution})
		if submitErr == nil {
			startedAt := metav1.NewTime(time.Now().UTC())
			state.TaskID = handle.TaskID
			state.Phase = warmupv1alpha1.JobPhaseRunning
			state.StartedAt = &startedAt
			return state
		}
		r.completeFailedCollective(ctx, logger, nodeObj, intent.TargetNode, execution, submitErr)
		failure = submitErr
	}
	state.Phase = warmupv1alpha1.JobPhasePending
	state.Message = failure.Error()
	state.FinishedAt = nil
	return state
}
// completeFailedCollective closes the collective session of an execution
// whose submission failed, reporting Success=false with the submit error as
// the message. Executions without a collective are ignored. A cleanup failure
// is only logged.
func (r *Reconciler) completeFailedCollective(
	ctx context.Context,
	logger *slog.Logger,
	nodeObj corev1.Node,
	targetNode string,
	execution sharedtypes.WarmupExecutionPlan,
	submitErr error,
) {
	if execution.CollectiveSpec.Mode != sharedtypes.CollectiveModeNone {
		req := sharedtypes.CompleteCollectiveRequest{
			TaskID:    execution.TaskID,
			SessionID: execution.CollectiveSpec.SessionID,
			Success:   false,
			Message:   submitErr.Error(),
		}
		err := r.Agent.CompleteCollective(ctx, nodeObj, req)
		if err != nil {
			logger.Warn("submit warmup failed after collective open, cleanup collective session failed", "targetNode", targetNode, "taskID", execution.TaskID, "error", err)
		}
	}
}
// collectiveCompletionTargets lists the collective sessions that can now be
// completed. There is one entry per node intent whose collective mode is not
// CollectiveModeNone.
//
// The result is all-or-nothing. It is nil when there are no intents or no
// states, or when any collective intent's node is missing from states, has no
// task ID yet, or is not in a terminal phase (per isTerminal). Each entry carries
// the node's task ID and phase-derived success flag, its message, the intent's
// session ID, and a staging path derived from the intent's target path. States
// with an empty NodeName are ignored. If several states share a NodeName, the
// last one wins.
func collectiveCompletionTargets(cachePlan transferplanner.CacheBuildPlan, states []warmupv1alpha1.WarmupNodeState) []collectiveCompletionTarget {
	if len(states) == 0 || len(cachePlan.NodeIntents) == 0 {
		return nil
	}
	byNode := make(map[string]warmupv1alpha1.WarmupNodeState, len(states))
	for _, st := range states {
		if st.NodeName != "" {
			byNode[st.NodeName] = st
		}
	}
	out := make([]collectiveCompletionTarget, 0, len(cachePlan.NodeIntents))
	for _, intent := range cachePlan.NodeIntents {
		if intent.Collective.Mode == sharedtypes.CollectiveModeNone {
			continue
		}
		st, found := byNode[intent.TargetNode]
		ready := found && st.TaskID != "" && isTerminal(st.Phase)
		if !ready {
			// At least one collective participant is still in flight: complete nothing yet.
			return nil
		}
		out = append(out, collectiveCompletionTarget{
			nodeName:    intent.TargetNode,
			taskID:      st.TaskID,
			sessionID:   intent.Collective.SessionID,
			success:     st.Phase == warmupv1alpha1.JobPhaseSucceeded,
			message:     st.Message,
			stagingPath: deriveCollectiveStagingPath(intent.TargetPath),
		})
	}
	return out
}
// deriveCollectiveStagingPath returns the staging location that pairs with
// readyPath. The path is cleaned first, and the result is a ".staging"
// directory beside the final element, containing an entry of the same name.
// For example, "/models/llama" maps to "/models/.staging/llama".
func deriveCollectiveStagingPath(readyPath string) string {
	normalized := filepath.Clean(readyPath)
	parentDir := filepath.Dir(normalized)
	leafName := filepath.Base(normalized)
	return filepath.Join(parentDir, ".staging", leafName)
}
// SetupWithManager registers the reconciler with mgr as the controller for
// ModelWarmupJob objects. Any setup error is wrapped with context.
func (r *Reconciler) SetupWithManager(mgr ctrl.Manager) error {
	controllerBuilder := ctrl.NewControllerManagedBy(mgr).For(&warmupv1alpha1.ModelWarmupJob{})
	setupErr := controllerBuilder.Complete(r)
	return errutil.Wrap("complete controller setup", setupErr)
}
// normalizeSources returns a copy of sources in which every node-type source
// gets its Endpoint from the referenced node's internal IP, resolved through
// r.Resolver.
//
// It fails if a node source has no NodeName, if the node cannot be resolved or
// has no extractable internal IP, or if two sources collapse to the same
// (sourceType, nodeName, endpoint, path) key after normalization. The input
// slice is not modified.
func (r *Reconciler) normalizeSources(ctx context.Context, sources []warmupv1alpha1.SourceSpec) ([]warmupv1alpha1.SourceSpec, error) {
	normalized := make([]warmupv1alpha1.SourceSpec, 0, len(sources))
	keys := make(map[string]struct{}, len(sources))
	for i := range sources {
		entry := sources[i]
		if entry.SourceType == sourceTypeNode {
			if entry.NodeName == "" {
				return nil, fmt.Errorf("sourceType=node requires nodeName")
			}
			nodeObj, err := r.Resolver.GetNode(ctx, entry.NodeName)
			if err != nil {
				return nil, errutil.Wrap("resolve normalized source node", err)
			}
			internalIP, err := node.ExtractNodeInternalIP(nodeObj)
			if err != nil {
				return nil, errutil.Wrap("extract normalized source endpoint", err)
			}
			entry.Endpoint = internalIP
		}
		dedupKey := strings.Join([]string{entry.SourceType, entry.NodeName, entry.Endpoint, entry.Path}, "|")
		if _, dup := keys[dedupKey]; dup {
			return nil, fmt.Errorf("duplicated normalized source %s", dedupKey)
		}
		keys[dedupKey] = struct{}{}
		normalized = append(normalized, entry)
	}
	return normalized, nil
}
// findNodeByName returns the first node in nodes whose Name equals name and
// true. If there is no such node, it returns a zero-value Node and false.
func findNodeByName(nodes []corev1.Node, name string) (corev1.Node, bool) {
	// Index instead of ranging by value so each (large) Node is not copied.
	for i := range nodes {
		if nodes[i].Name != name {
			continue
		}
		return nodes[i], true
	}
	return corev1.Node{}, false
}
// validateJob checks a ModelWarmupJobSpec before any work is planned.
//
// The target fields and every source are validated first. When at least one
// node-type source is present, the spec is accepted. Otherwise it must contain
// a supported external manifest source (see validateExternalManifestSource).
func validateJob(spec warmupv1alpha1.ModelWarmupJobSpec) error {
	if err := validateJobTarget(spec); err != nil {
		return err
	}
	hasNodeSource, err := validateJobSources(spec.Sources)
	switch {
	case err != nil:
		return err
	case hasNodeSource:
		return nil
	default:
		return validateExternalManifestSource(spec.Sources)
	}
}
// validateJobTarget checks the required top-level spec fields, in order:
//   - the artifact key
//   - the target path
//   - at least one of target node names or node selector
//   - a non-empty source list
//
// It returns an error describing the first missing field.
func validateJobTarget(spec warmupv1alpha1.ModelWarmupJobSpec) error {
	switch {
	case spec.Artifact.Key == "":
		return fmt.Errorf("spec.artifact.key is required")
	case spec.Target.TargetPath == "":
		return fmt.Errorf("spec.target.targetPath is required")
	case len(spec.Target.NodeNames) == 0 && len(spec.Target.NodeSelector) == 0:
		return fmt.Errorf("spec.target.nodeNames or spec.target.nodeSelector is required")
	case len(spec.Sources) == 0:
		return fmt.Errorf("spec.sources must not be empty in phase one")
	default:
		return nil
	}
}
// validateJobSources validates each source with validateOneSource and rejects
// duplicate (sourceType, nodeName, endpoint, path) combinations.
//
// It reports whether any node-type source was seen. On the first validation
// error it returns false together with that error.
func validateJobSources(sources []warmupv1alpha1.SourceSpec) (bool, error) {
	keys := make(map[string]struct{}, len(sources))
	foundNode := false
	for i := range sources {
		src := sources[i]
		if err := validateOneSource(src); err != nil {
			return false, err
		}
		foundNode = foundNode || src.SourceType == sourceTypeNode
		k := strings.Join([]string{src.SourceType, src.NodeName, src.Endpoint, src.Path}, "|")
		if _, exists := keys[k]; exists {
			return false, fmt.Errorf("duplicated source %s", k)
		}
		keys[k] = struct{}{}
	}
	return foundNode, nil
}
// validateOneSource checks the fields of a single source.
//
// Every source needs SourceType and Path. A node-type source must name a node
// and must leave Endpoint empty, because the endpoint is resolved from the node
// later. Any other source type must supply an explicit Endpoint.
func validateOneSource(source warmupv1alpha1.SourceSpec) error {
	if source.SourceType == "" || source.Path == "" {
		return fmt.Errorf("sourceType and path are required for every source")
	}
	if source.SourceType != sourceTypeNode {
		if source.Endpoint == "" {
			return fmt.Errorf("sourceType=%s requires endpoint", source.SourceType)
		}
		return nil
	}
	switch {
	case source.NodeName == "":
		return fmt.Errorf("sourceType=node requires nodeName")
	case source.Endpoint != "":
		return fmt.Errorf("sourceType=node must not set endpoint explicitly")
	}
	return nil
}
// validateExternalManifestSource accepts sources only when at least one is an
// external source whose endpoint rdma.IsHuggingFaceEndpoint recognizes.
// Otherwise it returns an error. validateJob uses it when no node-type source
// exists.
func validateExternalManifestSource(sources []warmupv1alpha1.SourceSpec) error {
	supported := slices.ContainsFunc(sources, func(s warmupv1alpha1.SourceSpec) bool {
		return s.SourceType == sourceTypeExternal && rdma.IsHuggingFaceEndpoint(s.SourceType, s.Endpoint)
	})
	if supported {
		return nil
	}
	return fmt.Errorf("phase one requires at least one node source or a supported external manifest source")
}
// applyTaskStatus copies an agent-reported task status onto state.
//
// Task ID, phase, message, progress bytes, throughput, transport path and cache
// path are always overwritten. StartedAt and FinishedAt are overwritten only
// when the status carries them, so previously recorded timestamps survive a
// status without them. Timestamps are normalized to UTC.
func applyTaskStatus(state *warmupv1alpha1.WarmupNodeState, status sharedtypes.TaskStatus) {
	state.TaskID = status.TaskID
	state.Phase = string(status.Phase)
	state.Message = status.Message
	state.CachePath = status.CachePath
	state.TransportPath = string(status.TransportPath)
	state.BytesTransferred = status.ProgressBytes
	state.ThroughputMBps = status.ThroughputMBps
	if ts := status.StartedAt; ts != nil {
		startedUTC := metav1.NewTime(ts.UTC())
		state.StartedAt = &startedUTC
	}
	if ts := status.FinishedAt; ts != nil {
		finishedUTC := metav1.NewTime(ts.UTC())
		state.FinishedAt = &finishedUTC
	}
}
// summarize tallies node states by phase:
//   - Succeeded states count as succeeded.
//   - Failed and Canceled states count as failed.
//   - Running states count as running.
//   - Any other phase, including an empty one, counts as pending.
//
// Total is the number of states.
func summarize(states []warmupv1alpha1.WarmupNodeState) warmupv1alpha1.WarmupSummary {
	var out warmupv1alpha1.WarmupSummary
	for i := range states {
		out.Total++
		switch states[i].Phase {
		case warmupv1alpha1.JobPhaseRunning:
			out.Running++
		case warmupv1alpha1.JobPhaseSucceeded:
			out.Succeeded++
		case warmupv1alpha1.JobPhaseFailed, string(sharedtypes.WarmupPhaseCanceled):
			out.Failed++
		default:
			out.Pending++
		}
	}
	return out
}
// aggregatePhase derives the job-level phase from a per-node summary. The first
// matching rule wins:
//   - an empty summary is Pending;
//   - any running or pending node makes the job Running;
//   - otherwise any failure makes it Failed;
//   - all nodes succeeded makes it Succeeded.
//
// NOTE(review): for summaries produced by summarize (where Total equals the sum
// of the buckets), the last two cases cannot be reached. They only apply to
// summaries built some other way.
func aggregatePhase(summary warmupv1alpha1.WarmupSummary) string {
	switch {
	case summary.Total == 0:
		return warmupv1alpha1.JobPhasePending
	case summary.Running > 0 || summary.Pending > 0:
		return warmupv1alpha1.JobPhaseRunning
	case summary.Failed > 0:
		return warmupv1alpha1.JobPhaseFailed
	case summary.Succeeded == summary.Total:
		return warmupv1alpha1.JobPhaseSucceeded
	case summary.Succeeded > 0:
		return warmupv1alpha1.JobPhaseRunning
	default:
		return warmupv1alpha1.JobPhasePending
	}
}
// isTerminal reports whether a node phase is final: Succeeded, Failed or
// Canceled.
//
// Canceled is included so this helper agrees with summarize, which counts
// Canceled states as failed. Without it, a canceled collective participant would
// never be treated as finished, and collectiveCompletionTargets would withhold
// completion of the whole collective indefinitely.
func isTerminal(phase string) bool {
	switch phase {
	case warmupv1alpha1.JobPhaseSucceeded, warmupv1alpha1.JobPhaseFailed, string(sharedtypes.WarmupPhaseCanceled):
		return true
	default:
		return false
	}
}