* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package warmupjob
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	neturl "net/url"
	"time"

	corev1 "k8s.io/api/core/v1"

	"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
	nodemeta "github.com/openfuyao/weight-dispatcher/pkg/node"
	sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)
// HTTPAgentClient sends HTTP requests to the node-agent that runs on a
// Kubernetes node.
type HTTPAgentClient struct {
	// client performs every outgoing request.
	client *http.Client
	// logger receives an error record whenever a call fails.
	logger *slog.Logger
	// port is the node-agent port joined with the node address when a request
	// URL is built.
	port int
}
// nodeAgentCallSpec describes a single node-agent endpoint call: where the
// request goes and how its failures are reported.
type nodeAgentCallSpec struct {
	// path is the URL path appended after the node's host:port.
	path string
	// operation is a short English name used when building error messages.
	operation string
	// buildURLMessage is the log message emitted when the node URL cannot be built.
	buildURLMessage string
	// callFailureMessage is the log message emitted when the HTTP call fails.
	callFailureMessage string
	// extraAttrs holds additional key/value pairs appended to the log attributes.
	extraAttrs []any
}
// Attribute keys used for structured log records emitted by this client.
const (
	// logAttrError keys the error value of a failed call.
	logAttrError = "error"
	// logAttrNode keys the name of the node being contacted.
	logAttrNode = "node"
	// logAttrTask keys the ID of the task a call relates to.
	logAttrTask = "taskID"
	// logAttrURL keys the request URL.
	logAttrURL = "url"
)
// AgentHTTPError reports that a node-agent call received an HTTP response
// whose status code marks the call as failed.
type AgentHTTPError struct {
	// Operation is the short name of the call, e.g. "submit warmup".
	Operation string
	// StatusCode is the HTTP status code returned by the node-agent.
	StatusCode int
}

// Error implements the error interface. A nil receiver yields an empty string
// rather than panicking.
func (e *AgentHTTPError) Error() string {
	if e == nil {
		return ""
	}
	return fmt.Sprintf("%s failed with status %d", e.Operation, e.StatusCode)
}

// IsAgentHTTPStatus reports whether err, or any error it wraps, is an
// *AgentHTTPError that carries exactly statusCode.
//
// errors.As also matches a typed-nil *AgentHTTPError stored in an error
// value. Such a value carries no status code, so it is reported as not
// matching instead of being dereferenced.
func IsAgentHTTPStatus(err error, statusCode int) bool {
	var target *AgentHTTPError
	if !errors.As(err, &target) || target == nil {
		return false
	}
	return target.StatusCode == statusCode
}
// NewHTTPAgentClient builds an HTTPAgentClient that reaches node-agents on the
// given port.
//
// A non-positive timeout is replaced by 10 seconds, and a nil logger is
// replaced by slog.Default().
func NewHTTPAgentClient(port int, timeout time.Duration, logger *slog.Logger) *HTTPAgentClient {
	const defaultTimeout = 10 * time.Second

	effectiveTimeout := timeout
	if effectiveTimeout <= 0 {
		effectiveTimeout = defaultTimeout
	}

	effectiveLogger := logger
	if effectiveLogger == nil {
		effectiveLogger = slog.Default()
	}

	agent := new(HTTPAgentClient)
	agent.client = &http.Client{Timeout: effectiveTimeout}
	agent.logger = effectiveLogger
	agent.port = port
	return agent
}
// SubmitWarmup POSTs req as JSON to /v1/warmups on the node-agent of node and
// returns the decoded task handle. Failures are logged together with the node
// name and the plan's task ID before being returned.
func (c *HTTPAgentClient) SubmitWarmup(ctx context.Context, node corev1.Node, req sharedtypes.SubmitWarmupRequest) (sharedtypes.TaskHandle, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/warmups",
		operation:          "submit warmup",
		buildURLMessage:    "构建 node-agent Warmup URL 失败",
		callFailureMessage: "向 node-agent 提交 Warmup 任务失败",
		extraAttrs:         []any{logAttrTask, req.Plan.TaskID},
	}
	return callNodeAgentPostJSON[sharedtypes.SubmitWarmupRequest, sharedtypes.TaskHandle](ctx, c, node, req, spec)
}
// BuildManifest POSTs req as JSON to /v1/manifests:build on the node-agent of
// node and returns the decoded response. Failures are logged together with the
// node name and the request's artifact key before being returned.
func (c *HTTPAgentClient) BuildManifest(ctx context.Context, node corev1.Node, req sharedtypes.BuildManifestRequest) (sharedtypes.BuildManifestResponse, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/manifests:build",
		operation:          "build manifest",
		buildURLMessage:    "构建 node-agent manifest URL 失败",
		callFailureMessage: "向 node-agent 构建逻辑 manifest 失败",
		extraAttrs:         []any{"artifactKey", req.ArtifactKey},
	}
	return callNodeAgentPostJSON[sharedtypes.BuildManifestRequest, sharedtypes.BuildManifestResponse](ctx, c, node, req, spec)
}
// GetWarmupTaskStatus issues a GET for /v1/warmups/<task ID> on the node-agent
// of node and returns the decoded task status. Failures are logged together
// with the node name and the task ID before being returned.
func (c *HTTPAgentClient) GetWarmupTaskStatus(ctx context.Context, node corev1.Node, req sharedtypes.GetWarmupTaskStatusRequest) (sharedtypes.TaskStatus, error) {
	return callNodeAgentGetJSON[sharedtypes.TaskStatus](
		ctx,
		c,
		node,
		nodeAgentCallSpec{
			// Escape the ID so it always stays a single path segment: characters
			// such as '/', '?', '#' or '%' would otherwise change the route or
			// start a query/fragment. IDs made of unreserved characters are
			// left unchanged by PathEscape.
			path:               "/v1/warmups/" + neturl.PathEscape(req.TaskID),
			operation:          "get warmup status",
			buildURLMessage:    "构建 node-agent 任务状态 URL 失败",
			callFailureMessage: "获取 node-agent 任务状态失败",
			extraAttrs:         []any{logAttrTask, req.TaskID},
		},
	)
}
// OpenCollective POSTs req as JSON to /v1/collectives/open on the node-agent
// of node and returns the decoded response. Failures are logged together with
// the node name and the task ID before being returned.
func (c *HTTPAgentClient) OpenCollective(ctx context.Context, node corev1.Node, req sharedtypes.OpenCollectiveRequest) (sharedtypes.OpenCollectiveResponse, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/collectives/open",
		operation:          "open collective",
		buildURLMessage:    "构建 node-agent collective open URL 失败",
		callFailureMessage: "打开 node-agent collective 会话失败",
		extraAttrs:         []any{logAttrTask, req.TaskID},
	}
	return callNodeAgentPostJSON[sharedtypes.OpenCollectiveRequest, sharedtypes.OpenCollectiveResponse](ctx, c, node, req, spec)
}
// StepCollective POSTs req as JSON to /v1/collectives/step on the node-agent
// of node and returns the decoded response. Failures are logged together with
// the node name and the task ID before being returned.
func (c *HTTPAgentClient) StepCollective(ctx context.Context, node corev1.Node, req sharedtypes.CollectiveStepRequest) (sharedtypes.CollectiveStepResponse, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/collectives/step",
		operation:          "step collective",
		buildURLMessage:    "构建 node-agent collective step URL 失败",
		callFailureMessage: "推进 node-agent collective 失败",
		extraAttrs:         []any{logAttrTask, req.TaskID},
	}
	return callNodeAgentPostJSON[sharedtypes.CollectiveStepRequest, sharedtypes.CollectiveStepResponse](ctx, c, node, req, spec)
}
// CompleteCollective POSTs req as JSON to /v1/collectives/complete on the
// node-agent of node. No response payload is decoded; only the status code
// decides success. Failures are logged together with the node name and the
// task ID before being returned.
func (c *HTTPAgentClient) CompleteCollective(ctx context.Context, node corev1.Node, req sharedtypes.CompleteCollectiveRequest) error {
	spec := nodeAgentCallSpec{
		path:               "/v1/collectives/complete",
		operation:          "complete collective",
		buildURLMessage:    "构建 node-agent collective complete URL 失败",
		callFailureMessage: "关闭 node-agent collective 会话失败",
		extraAttrs:         []any{logAttrTask, req.TaskID},
	}
	return callNodeAgentPostNoContent(ctx, c, node, req, spec)
}
// callNodeAgentPostJSON resolves the node-agent URL for spec.path on node,
// POSTs req to it as JSON and decodes the JSON reply into Resp.
//
// Both failure modes are logged through client.logger with the node name,
// the error and spec.extraAttrs: a URL that cannot be built (the error is
// also wrapped with the operation name) and a failed call (the request URL is
// logged as well and the error is returned as produced by postJSON).
func callNodeAgentPostJSON[Req, Resp any](
	ctx context.Context,
	client *HTTPAgentClient,
	node corev1.Node,
	req Req,
	spec nodeAgentCallSpec,
) (Resp, error) {
	target, urlErr := client.nodeURL(&node, spec.path)
	if urlErr != nil {
		attrs := []any{logAttrNode, node.Name, logAttrError, urlErr}
		attrs = append(attrs, spec.extraAttrs...)
		client.logger.Error(spec.buildURLMessage, attrs...)

		var zero Resp
		return zero, errutil.Wrap("build "+spec.operation+" URL", urlErr)
	}

	result, err := postJSON[Req, Resp](ctx, client.client, target, req, spec.operation)
	if err == nil {
		return result, nil
	}

	attrs := []any{logAttrNode, node.Name, logAttrURL, target, logAttrError, err}
	attrs = append(attrs, spec.extraAttrs...)
	client.logger.Error(spec.callFailureMessage, attrs...)
	return result, err
}
// callNodeAgentGetJSON resolves the node-agent URL for spec.path on node,
// issues a GET against it and decodes the JSON reply into Resp.
//
// Both failure modes are logged through client.logger with the node name,
// the error and spec.extraAttrs: a URL that cannot be built (the error is
// also wrapped with the operation name) and a failed call (the request URL is
// logged as well and the error is returned as produced by getJSON).
func callNodeAgentGetJSON[Resp any](
	ctx context.Context,
	client *HTTPAgentClient,
	node corev1.Node,
	spec nodeAgentCallSpec,
) (Resp, error) {
	target, urlErr := client.nodeURL(&node, spec.path)
	if urlErr != nil {
		attrs := []any{logAttrNode, node.Name, logAttrError, urlErr}
		attrs = append(attrs, spec.extraAttrs...)
		client.logger.Error(spec.buildURLMessage, attrs...)

		var zero Resp
		return zero, errutil.Wrap("build "+spec.operation+" URL", urlErr)
	}

	result, err := getJSON[Resp](ctx, client.client, target, spec.operation)
	if err == nil {
		return result, nil
	}

	attrs := []any{logAttrNode, node.Name, logAttrURL, target, logAttrError, err}
	attrs = append(attrs, spec.extraAttrs...)
	client.logger.Error(spec.callFailureMessage, attrs...)
	return result, err
}
// callNodeAgentPostNoContent resolves the node-agent URL for spec.path on
// node and POSTs req to it as JSON without decoding a response payload.
//
// Both failure modes are logged through client.logger with the node name,
// the error and spec.extraAttrs: a URL that cannot be built (the error is
// also wrapped with the operation name) and a failed call (the request URL is
// logged as well and the error is returned as produced by postNoContent).
func callNodeAgentPostNoContent[Req any](
	ctx context.Context,
	client *HTTPAgentClient,
	node corev1.Node,
	req Req,
	spec nodeAgentCallSpec,
) error {
	target, urlErr := client.nodeURL(&node, spec.path)
	if urlErr != nil {
		attrs := []any{logAttrNode, node.Name, logAttrError, urlErr}
		attrs = append(attrs, spec.extraAttrs...)
		client.logger.Error(spec.buildURLMessage, attrs...)
		return errutil.Wrap("build "+spec.operation+" URL", urlErr)
	}

	err := postNoContent(ctx, client.client, target, req, spec.operation)
	if err != nil {
		attrs := []any{logAttrNode, node.Name, logAttrURL, target, logAttrError, err}
		attrs = append(attrs, spec.extraAttrs...)
		client.logger.Error(spec.callFailureMessage, attrs...)
	}
	return err
}
// postJSON marshals req to JSON, POSTs it to url with a JSON content type and
// hands the HTTP response to decodeJSONResponse, which produces the Resp value.
//
// Marshalling, request construction and transport failures are each wrapped
// via errutil.Wrap with a message naming the failing step and operation; in
// those cases the zero value of Resp is returned.
func postJSON[Req, Resp any](ctx context.Context, client *http.Client, url string, req Req, operation string) (Resp, error) {
	var zero Resp

	payload, marshalErr := json.Marshal(req)
	if marshalErr != nil {
		return zero, errutil.Wrap("marshal "+operation+" request", marshalErr)
	}

	request, buildErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
	if buildErr != nil {
		return zero, errutil.Wrap("build "+operation+" request", buildErr)
	}
	request.Header.Set("Content-Type", "application/json")

	response, sendErr := client.Do(request)
	if sendErr != nil {
		return zero, errutil.Wrap("send "+operation+" request", sendErr)
	}
	return decodeJSONResponse[Resp](response, operation)
}
// getJSON issues a body-less GET to url and hands the HTTP response to
// decodeJSONResponse, which produces the Resp value.
//
// Request construction and transport failures are wrapped via errutil.Wrap
// with a message naming the failing step and operation; in those cases the
// zero value of Resp is returned.
func getJSON[Resp any](ctx context.Context, client *http.Client, url, operation string) (Resp, error) {
	var zero Resp

	request, buildErr := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
	if buildErr != nil {
		return zero, errutil.Wrap("build "+operation+" request", buildErr)
	}

	response, sendErr := client.Do(request)
	if sendErr != nil {
		return zero, errutil.Wrap("send "+operation+" request", sendErr)
	}
	return decodeJSONResponse[Resp](response, operation)
}
// decodeJSONResponse turns resp into a Resp value and always closes the
// response body.
//
// The status check and JSON decoding are delegated to decodeJSONResponseBody.
// Its error takes precedence; a body close error is reported only when
// decoding succeeded. On any error the zero value of Resp is returned.
func decodeJSONResponse[Resp any](resp *http.Response, operation string) (Resp, error) {
	// maxDrainBytes bounds how much unread body is discarded before closing.
	const maxDrainBytes = 64 << 10

	var output Resp
	resultErr := decodeJSONResponseBody(resp, operation, &output)

	// Best-effort drain: the HTTP transport may not reuse a keep-alive
	// connection unless the body was read to EOF before Close, and on an error
	// status the body has not been read at all. The drain result is
	// deliberately ignored because it must not change the call's outcome.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))

	closeErr := resp.Body.Close()
	if resultErr == nil && closeErr != nil {
		resultErr = errutil.Wrap("close "+operation+" response body", closeErr)
	}
	if resultErr != nil {
		return *new(Resp), resultErr
	}
	return output, nil
}
// decodeJSONResponseBody validates the status of resp and decodes its JSON
// body into output.
//
// A status code of 300 or above yields an *AgentHTTPError without reading the
// body. A decoding failure is wrapped via errutil.Wrap. The body is not closed
// here; that is left to the caller.
func decodeJSONResponseBody[Resp any](resp *http.Response, operation string, output *Resp) error {
	if status := resp.StatusCode; status >= 300 {
		return &AgentHTTPError{Operation: operation, StatusCode: status}
	}

	decodeErr := json.NewDecoder(resp.Body).Decode(output)
	if decodeErr == nil {
		return nil
	}
	return errutil.Wrap("decode "+operation+" response", decodeErr)
}
// postNoContent marshals req to JSON and POSTs it to url with a JSON content
// type. No response payload is decoded: only the status code, checked by
// noContentResponseError, decides success.
//
// Marshalling, request construction and transport failures are wrapped via
// errutil.Wrap with a message naming the failing step and operation. The
// response body is always closed; a close error is reported only when the
// status check passed.
func postNoContent[Req any](ctx context.Context, client *http.Client, url string, req Req, operation string) error {
	// maxDrainBytes bounds how much unread body is discarded before closing.
	const maxDrainBytes = 64 << 10

	body, err := json.Marshal(req)
	if err != nil {
		return errutil.Wrap("marshal "+operation+" request", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return errutil.Wrap("build "+operation+" request", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := client.Do(httpReq)
	if err != nil {
		return errutil.Wrap("send "+operation+" request", err)
	}
	resultErr := noContentResponseError(resp, operation)

	// Best-effort drain: the HTTP transport may not reuse a keep-alive
	// connection unless the body was read to EOF before Close, and this
	// function never reads the body otherwise. The drain result is deliberately
	// ignored because it must not change the call's outcome.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))

	closeErr := resp.Body.Close()
	if resultErr == nil && closeErr != nil {
		return errutil.Wrap("close "+operation+" response body", closeErr)
	}
	return resultErr
}
// noContentResponseError maps the status of resp to an error: a status code
// of 300 or above yields an *AgentHTTPError for operation, anything lower
// yields nil. The response body is neither read nor closed.
func noContentResponseError(resp *http.Response, operation string) error {
	if status := resp.StatusCode; status >= 300 {
		return &AgentHTTPError{Operation: operation, StatusCode: status}
	}
	return nil
}
// nodeURL builds the plain-HTTP URL for path on the node-agent of node, using
// the host reported by nodeURLHost and the client's configured port. The
// host lookup error is returned unchanged.
func (c *HTTPAgentClient) nodeURL(node *corev1.Node, path string) (string, error) {
	host, err := nodeURLHost(node)
	if err != nil {
		return "", err
	}
	// JoinHostPort brackets IPv6 literals so the authority stays well-formed.
	hostPort := net.JoinHostPort(host, fmt.Sprint(c.port))
	return "http://" + hostPort + path, nil
}
// nodeURLHost returns the host part used to reach the node-agent on node by
// delegating to nodemeta.ExtractNodeInternalIP; its result and error are
// passed through unchanged.
func nodeURLHost(node *corev1.Node) (string, error) {
	host, err := nodemeta.ExtractNodeInternalIP(node)
	return host, err
}