/*
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * openFuyao is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package warmupjob

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"
	"testing"
	"time"

	sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
	corev1 "k8s.io/api/core/v1"
)

// roundTripFunc adapts a plain function into an http.RoundTripper so tests can
// stub the transport of an *http.Client and inspect outgoing requests.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip invokes f with req. The http.RoundTripper contract requires the
// request body to be closed in all cases, including when f returns an error.
// The close is deferred until f has returned, so f can still read the body.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Body != nil {
		// The close error is irrelevant for an in-memory test fake.
		defer req.Body.Close()
	}
	return f(req)
}

// TestHTTPAgentClientBuildsNodeURLFromInternalIP checks two things. First,
// when a node lists both a HostName and an InternalIP, the agent URL is built
// from the InternalIP. Second, a zero timeout argument to NewHTTPAgentClient
// yields a 10s client timeout.
func TestHTTPAgentClientBuildsNodeURLFromInternalIP(t *testing.T) {
	t.Parallel()

	addrs := []corev1.NodeAddress{
		{Type: corev1.NodeHostName, Address: "node-a"},
		{Type: corev1.NodeInternalIP, Address: "10.0.0.10"},
	}
	node := corev1.Node{Status: corev1.NodeStatus{Addresses: addrs}}
	agent := NewHTTPAgentClient(18080, 0, nil)

	const wantURL = "http://10.0.0.10:18080/v1/warmups"
	gotURL, err := agent.nodeURL(&node, "/v1/warmups")
	switch {
	case err != nil:
		t.Fatalf("nodeURL returned error: %v", err)
	case gotURL != wantURL:
		t.Fatalf("expected url %q, got %q", wantURL, gotURL)
	}

	if timeout := agent.client.Timeout; timeout != 10*time.Second {
		t.Fatalf("expected default timeout, got %s", timeout)
	}
}

// TestHTTPAgentClientBuildsNodeURLFromIPv6InternalIP checks that an IPv6
// InternalIP is wrapped in square brackets in the generated agent URL.
func TestHTTPAgentClientBuildsNodeURLFromIPv6InternalIP(t *testing.T) {
	t.Parallel()

	node := corev1.Node{}
	node.Status.Addresses = []corev1.NodeAddress{
		{Type: corev1.NodeInternalIP, Address: "fd00::10"},
	}
	agent := NewHTTPAgentClient(18080, 0, nil)

	gotURL, err := agent.nodeURL(&node, "/v1/warmups")
	if err != nil {
		t.Fatalf("nodeURL returned error: %v", err)
	}
	const wantURL = "http://[fd00::10]:18080/v1/warmups"
	if gotURL != wantURL {
		t.Fatalf("expected url %q, got %q", wantURL, gotURL)
	}
}

func TestNodeURLHostFallsBackToAnyAddressAndReportsEmptyNode(t *testing.T) {
	t.Parallel()

	host, err := nodeURLHost(&corev1.Node{
		Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{Type: corev1.NodeHostName, Address: "node-a"}}},
	})
	if err != nil {
		t.Fatalf("nodeURLHost returned error: %v", err)
	}
	if host != "node-a" {
		t.Fatalf("expected host fallback, got %q", host)
	}

	if _, err := nodeURLHost(&corev1.Node{}); err == nil {
		t.Fatalf("expected empty node address error")
	}
}

// TestPostJSONEncodesRequestAndDecodesResponse checks both directions of
// postJSON. The fake transport verifies the outgoing request: a POST with a
// JSON content type whose body decodes to the submitted plan. The test then
// verifies the JSON reply is decoded into the response type.
func TestPostJSONEncodesRequestAndDecodesResponse(t *testing.T) {
	t.Parallel()

	fake := roundTripFunc(func(req *http.Request) (*http.Response, error) {
		if method := req.Method; method != http.MethodPost {
			t.Fatalf("expected POST, got %s", method)
		}
		if ct := req.Header.Get("Content-Type"); ct != "application/json" {
			t.Fatalf("expected json content type, got %q", ct)
		}
		var sent sharedtypes.SubmitWarmupRequest
		if err := json.NewDecoder(req.Body).Decode(&sent); err != nil {
			t.Fatalf("decode request: %v", err)
		}
		if sent.Plan.TaskID != "task-a" {
			t.Fatalf("expected task-a request, got %#v", sent.Plan)
		}
		reply := sharedtypes.TaskHandle{TaskID: "task-a", NodeName: "node-a"}
		return jsonResponse(http.StatusOK, reply), nil
	})

	payload := sharedtypes.SubmitWarmupRequest{
		Plan: sharedtypes.WarmupExecutionPlan{TaskID: "task-a"},
	}
	handle, err := postJSON[sharedtypes.SubmitWarmupRequest, sharedtypes.TaskHandle](
		context.Background(),
		&http.Client{Transport: fake},
		"http://node-a/v1/warmups",
		payload,
		"submit warmup",
	)
	if err != nil {
		t.Fatalf("postJSON returned error: %v", err)
	}
	if handle.TaskID != "task-a" || handle.NodeName != "node-a" {
		t.Fatalf("unexpected response: %#v", handle)
	}
}

// TestHTTPHelpersReturnStatusAndDecodeErrors checks two failure modes of
// getJSON. A 409 Conflict reply must be reported as an agent HTTP status error
// recognised by IsAgentHTTPStatus. A 200 reply with truncated JSON must
// produce an error.
func TestHTTPHelpersReturnStatusAndDecodeErrors(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	const statusURL = "http://node-a/status"

	conflict := &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		return jsonResponse(http.StatusConflict, map[string]string{"error": "duplicate"}), nil
	})}
	_, conflictErr := getJSON[sharedtypes.TaskStatus](ctx, conflict, statusURL, "get warmup status")
	if !IsAgentHTTPStatus(conflictErr, http.StatusConflict) {
		t.Fatalf("expected status conflict error, got %v", conflictErr)
	}

	malformed := &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		truncated := io.NopCloser(strings.NewReader("{"))
		return &http.Response{StatusCode: http.StatusOK, Body: truncated}, nil
	})}
	_, decodeErr := getJSON[sharedtypes.TaskStatus](ctx, malformed, statusURL, "get warmup status")
	if decodeErr == nil {
		t.Fatalf("expected decode error")
	}
}

func TestPostNoContentHandlesSuccessStatusAndTransportError(t *testing.T) {
	t.Parallel()

	okClient := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
		if req.Method != http.MethodPost {
			t.Fatalf("expected POST, got %s", req.Method)
		}
		return &http.Response{StatusCode: http.StatusNoContent, Body: http.NoBody}, nil
	})}
	if err := postNoContent(context.Background(), okClient, "http://node-a/complete", sharedtypes.CompleteCollectiveRequest{}, "complete collective"); err != nil {
		t.Fatalf("postNoContent returned error: %v", err)
	}

	transportErr := errors.New("network down")
	failClient := &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		return nil, transportErr
	})}
	if err := postNoContent(context.Background(), failClient, "http://node-a/complete", sharedtypes.CompleteCollectiveRequest{}, "complete collective"); err == nil || !strings.Contains(err.Error(), "network down") {
		t.Fatalf("expected wrapped transport error, got %v", err)
	}
}

// TestHTTPAgentClientMethodsUseExpectedPaths calls every HTTPAgentClient
// method against a fake transport. Taken together, the calls must hit each
// expected node-agent endpoint exactly once and no other path.
func TestHTTPAgentClientMethodsUseExpectedPaths(t *testing.T) {
	t.Parallel()

	expectedPaths := []string{
		"/v1/warmups",
		"/v1/manifests:build",
		"/v1/warmups/task-a",
		"/v1/collectives/open",
		"/v1/collectives/step",
		"/v1/collectives/complete",
	}
	seenPaths := make(map[string]int, len(expectedPaths))
	client := NewHTTPAgentClient(18080, time.Second, nil)
	client.client.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
		seenPaths[req.URL.Path]++
		switch req.URL.Path {
		case "/v1/warmups":
			return jsonResponse(http.StatusOK, sharedtypes.TaskHandle{TaskID: "task-a"}), nil
		case "/v1/manifests:build":
			return jsonResponse(http.StatusOK, sharedtypes.BuildManifestResponse{}), nil
		case "/v1/warmups/task-a":
			return jsonResponse(http.StatusOK, sharedtypes.TaskStatus{TaskID: "task-a", Phase: sharedtypes.WarmupPhaseRunning}), nil
		case "/v1/collectives/open":
			return jsonResponse(http.StatusOK, sharedtypes.OpenCollectiveResponse{TaskID: "task-a"}), nil
		case "/v1/collectives/step":
			return jsonResponse(http.StatusOK, sharedtypes.CollectiveStepResponse{TaskID: "task-a"}), nil
		case "/v1/collectives/complete":
			return &http.Response{StatusCode: http.StatusNoContent, Body: http.NoBody}, nil
		default:
			// Record the failure and return a real error instead of calling
			// t.Fatalf and returning (nil, nil). A RoundTripper must return
			// either a response or an error. Returning an error also lets the
			// calling method fail normally rather than exiting the goroutine
			// from inside http.Client.Do.
			t.Errorf("unexpected path %s", req.URL.Path)
			return nil, errors.New("unexpected path " + req.URL.Path)
		}
	})
	nodeObj := corev1.Node{Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{Type: corev1.NodeInternalIP, Address: "10.0.0.10"}}}}

	if _, err := client.SubmitWarmup(context.Background(), nodeObj, sharedtypes.SubmitWarmupRequest{Plan: sharedtypes.WarmupExecutionPlan{TaskID: "task-a"}}); err != nil {
		t.Fatalf("SubmitWarmup: %v", err)
	}
	if _, err := client.BuildManifest(context.Background(), nodeObj, sharedtypes.BuildManifestRequest{ArtifactKey: "artifact-a"}); err != nil {
		t.Fatalf("BuildManifest: %v", err)
	}
	if _, err := client.GetWarmupTaskStatus(context.Background(), nodeObj, sharedtypes.GetWarmupTaskStatusRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("GetWarmupTaskStatus: %v", err)
	}
	if _, err := client.OpenCollective(context.Background(), nodeObj, sharedtypes.OpenCollectiveRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("OpenCollective: %v", err)
	}
	if _, err := client.StepCollective(context.Background(), nodeObj, sharedtypes.CollectiveStepRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("StepCollective: %v", err)
	}
	if err := client.CompleteCollective(context.Background(), nodeObj, sharedtypes.CompleteCollectiveRequest{TaskID: "task-a"}); err != nil {
		t.Fatalf("CompleteCollective: %v", err)
	}

	// Check each endpoint by name. A missing or duplicated call is then
	// reported explicitly, rather than only as a wrong count of distinct paths.
	for _, path := range expectedPaths {
		if n := seenPaths[path]; n != 1 {
			t.Fatalf("expected node-agent path %s to be called once, got %d (all calls: %#v)", path, n, seenPaths)
		}
	}
	if len(seenPaths) != len(expectedPaths) {
		t.Fatalf("expected all node-agent paths to be called, got %#v", seenPaths)
	}
}

// jsonResponse builds a fake *http.Response with the given status code. Its
// body is the JSON encoding of payload and its Header is an empty, non-nil
// map. It panics if payload cannot be marshaled, since that means the test
// fixture itself is broken.
func jsonResponse(status int, payload any) *http.Response {
	encoded, marshalErr := json.Marshal(payload)
	if marshalErr != nil {
		panic(marshalErr)
	}
	resp := &http.Response{StatusCode: status, Header: http.Header{}}
	resp.Body = io.NopCloser(strings.NewReader(string(encoded)))
	return resp
}