* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package warmupjob
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"testing"
"time"
sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
corev1 "k8s.io/api/core/v1"
)
// roundTripFunc adapts an ordinary function to the http.RoundTripper
// interface so tests can stub transport behavior inline.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip implements http.RoundTripper by delegating to the wrapped function.
func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return fn(r)
}
// TestHTTPAgentClientBuildsNodeURLFromInternalIP checks that the node URL is
// built from the InternalIP address, even when a hostname is listed first.
// It also checks that passing a zero timeout leaves the client with a 10s timeout.
func TestHTTPAgentClientBuildsNodeURLFromInternalIP(t *testing.T) {
	t.Parallel()
	const (
		wantURL     = "http://10.0.0.10:18080/v1/warmups"
		wantTimeout = 10 * time.Second
	)
	addrs := []corev1.NodeAddress{
		{Type: corev1.NodeHostName, Address: "node-a"},
		{Type: corev1.NodeInternalIP, Address: "10.0.0.10"},
	}
	node := &corev1.Node{Status: corev1.NodeStatus{Addresses: addrs}}
	agent := NewHTTPAgentClient(18080, 0, nil)

	gotURL, err := agent.nodeURL(node, "/v1/warmups")
	if err != nil {
		t.Fatalf("nodeURL returned error: %v", err)
	}
	if gotURL != wantURL {
		t.Fatalf("expected url %q, got %q", wantURL, gotURL)
	}
	if timeout := agent.client.Timeout; timeout != wantTimeout {
		t.Fatalf("expected default timeout, got %s", timeout)
	}
}
// TestHTTPAgentClientBuildsNodeURLFromIPv6InternalIP checks that an IPv6
// InternalIP is wrapped in square brackets when joined with the agent port.
func TestHTTPAgentClientBuildsNodeURLFromIPv6InternalIP(t *testing.T) {
	t.Parallel()
	const want = "http://[fd00::10]:18080/v1/warmups"

	node := &corev1.Node{}
	node.Status.Addresses = []corev1.NodeAddress{
		{Type: corev1.NodeInternalIP, Address: "fd00::10"},
	}
	agent := NewHTTPAgentClient(18080, 0, nil)

	got, err := agent.nodeURL(node, "/v1/warmups")
	switch {
	case err != nil:
		t.Fatalf("nodeURL returned error: %v", err)
	case got != want:
		t.Fatalf("expected url %q, got %q", want, got)
	}
}
func TestNodeURLHostFallsBackToAnyAddressAndReportsEmptyNode(t *testing.T) {
t.Parallel()
host, err := nodeURLHost(&corev1.Node{
Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{Type: corev1.NodeHostName, Address: "node-a"}}},
})
if err != nil {
t.Fatalf("nodeURLHost returned error: %v", err)
}
if host != "node-a" {
t.Fatalf("expected host fallback, got %q", host)
}
if _, err := nodeURLHost(&corev1.Node{}); err == nil {
t.Fatalf("expected empty node address error")
}
}
// TestPostJSONEncodesRequestAndDecodesResponse checks the request postJSON
// sends: a POST with a JSON content type and a decodable body. It also checks
// that the stubbed JSON reply is decoded into the typed result.
func TestPostJSONEncodesRequestAndDecodesResponse(t *testing.T) {
	t.Parallel()
	transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
		if method := req.Method; method != http.MethodPost {
			t.Fatalf("expected POST, got %s", method)
		}
		if contentType := req.Header.Get("Content-Type"); contentType != "application/json" {
			t.Fatalf("expected json content type, got %q", contentType)
		}
		var decoded sharedtypes.SubmitWarmupRequest
		if err := json.NewDecoder(req.Body).Decode(&decoded); err != nil {
			t.Fatalf("decode request: %v", err)
		}
		if decoded.Plan.TaskID != "task-a" {
			t.Fatalf("expected task-a request, got %#v", decoded.Plan)
		}
		reply := sharedtypes.TaskHandle{TaskID: "task-a", NodeName: "node-a"}
		return jsonResponse(http.StatusOK, reply), nil
	})
	stubClient := &http.Client{Transport: transport}
	payload := sharedtypes.SubmitWarmupRequest{Plan: sharedtypes.WarmupExecutionPlan{TaskID: "task-a"}}

	handle, err := postJSON[sharedtypes.SubmitWarmupRequest, sharedtypes.TaskHandle](
		context.Background(),
		stubClient,
		"http://node-a/v1/warmups",
		payload,
		"submit warmup",
	)
	if err != nil {
		t.Fatalf("postJSON returned error: %v", err)
	}
	if handle.TaskID != "task-a" || handle.NodeName != "node-a" {
		t.Fatalf("unexpected response: %#v", handle)
	}
}
// TestHTTPHelpersReturnStatusAndDecodeErrors checks two failure modes of getJSON.
// A non-2xx reply must produce an error recognized by IsAgentHTTPStatus. A 200
// reply with a truncated JSON body must produce an error.
func TestHTTPHelpersReturnStatusAndDecodeErrors(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	const statusURL = "http://node-a/status"

	conflictClient := &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		return jsonResponse(http.StatusConflict, map[string]string{"error": "duplicate"}), nil
	})}
	_, err := getJSON[sharedtypes.TaskStatus](ctx, conflictClient, statusURL, "get warmup status")
	if !IsAgentHTTPStatus(err, http.StatusConflict) {
		t.Fatalf("expected status conflict error, got %v", err)
	}

	truncatedClient := &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("{"))}, nil
	})}
	_, err = getJSON[sharedtypes.TaskStatus](ctx, truncatedClient, statusURL, "get warmup status")
	if err == nil {
		t.Fatalf("expected decode error")
	}
}
func TestPostNoContentHandlesSuccessStatusAndTransportError(t *testing.T) {
t.Parallel()
okClient := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", req.Method)
}
return &http.Response{StatusCode: http.StatusNoContent, Body: http.NoBody}, nil
})}
if err := postNoContent(context.Background(), okClient, "http://node-a/complete", sharedtypes.CompleteCollectiveRequest{}, "complete collective"); err != nil {
t.Fatalf("postNoContent returned error: %v", err)
}
transportErr := errors.New("network down")
failClient := &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, transportErr
})}
if err := postNoContent(context.Background(), failClient, "http://node-a/complete", sharedtypes.CompleteCollectiveRequest{}, "complete collective"); err == nil || !strings.Contains(err.Error(), "network down") {
t.Fatalf("expected wrapped transport error, got %v", err)
}
}
// TestHTTPAgentClientMethodsUseExpectedPaths calls each HTTPAgentClient method
// against a stub transport and checks that each one hits its own node-agent
// endpoint exactly once.
func TestHTTPAgentClientMethodsUseExpectedPaths(t *testing.T) {
	t.Parallel()
	// Endpoints in the order the client methods are invoked below.
	wantPaths := []string{
		"/v1/warmups",
		"/v1/manifests:build",
		"/v1/warmups/task-a",
		"/v1/collectives/open",
		"/v1/collectives/step",
		"/v1/collectives/complete",
	}
	seenPaths := make(map[string]int, len(wantPaths))
	client := NewHTTPAgentClient(18080, time.Second, nil)
	client.client.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
		seenPaths[req.URL.Path]++
		switch req.URL.Path {
		case "/v1/warmups":
			return jsonResponse(http.StatusOK, sharedtypes.TaskHandle{TaskID: "task-a"}), nil
		case "/v1/manifests:build":
			return jsonResponse(http.StatusOK, sharedtypes.BuildManifestResponse{}), nil
		case "/v1/warmups/task-a":
			return jsonResponse(http.StatusOK, sharedtypes.TaskStatus{TaskID: "task-a", Phase: sharedtypes.WarmupPhaseRunning}), nil
		case "/v1/collectives/open":
			return jsonResponse(http.StatusOK, sharedtypes.OpenCollectiveResponse{TaskID: "task-a"}), nil
		case "/v1/collectives/step":
			return jsonResponse(http.StatusOK, sharedtypes.CollectiveStepResponse{TaskID: "task-a"}), nil
		case "/v1/collectives/complete":
			return &http.Response{StatusCode: http.StatusNoContent, Body: http.NoBody}, nil
		default:
			t.Fatalf("unexpected path %s", req.URL.Path)
			return nil, nil
		}
	})
	nodeObj := corev1.Node{Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{Type: corev1.NodeInternalIP, Address: "10.0.0.10"}}}}
	if _, err := client.SubmitWarmup(context.Background(), nodeObj, sharedtypes.SubmitWarmupRequest{Plan: sharedtypes.WarmupExecutionPlan{TaskID: "task-a"}}); err != nil {
		t.Fatalf("SubmitWarmup: %v", err)
	}
	if _, err := client.BuildManifest(context.Background(), nodeObj, sharedtypes.BuildManifestRequest{ArtifactKey: "artifact-a"}); err != nil {
		t.Fatalf("BuildManifest: %v", err)
	}
	if _, err := client.GetWarmupTaskStatus(context.Background(), nodeObj, sharedtypes.GetWarmupTaskStatusRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("GetWarmupTaskStatus: %v", err)
	}
	if _, err := client.OpenCollective(context.Background(), nodeObj, sharedtypes.OpenCollectiveRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("OpenCollective: %v", err)
	}
	if _, err := client.StepCollective(context.Background(), nodeObj, sharedtypes.CollectiveStepRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("StepCollective: %v", err)
	}
	if err := client.CompleteCollective(context.Background(), nodeObj, sharedtypes.CompleteCollectiveRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("CompleteCollective: %v", err)
	}
	if len(seenPaths) != len(wantPaths) {
		t.Fatalf("expected all node-agent paths to be called, got %#v", seenPaths)
	}
	// The stub counts calls per path, so assert the counts rather than only
	// the number of distinct paths. A method that hit its endpoint twice would
	// otherwise go unnoticed.
	for _, path := range wantPaths {
		if calls := seenPaths[path]; calls != 1 {
			t.Fatalf("expected %s to be called once, got %d (all calls: %#v)", path, calls, seenPaths)
		}
	}
}
// jsonResponse builds a stub *http.Response with the given status code, an
// empty non-nil header map, and payload marshaled as JSON for the body.
// It panics if payload cannot be marshaled, which signals a broken test fixture.
func jsonResponse(status int, payload any) *http.Response {
	encoded, err := json.Marshal(payload)
	if err != nil {
		panic(err)
	}
	resp := new(http.Response)
	resp.StatusCode = status
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(string(encoded)))
	return resp
}