/*
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * openFuyao is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package warmupjob

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"time"

	corev1 "k8s.io/api/core/v1"

	"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
	nodemeta "github.com/openfuyao/weight-dispatcher/pkg/node"
	sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)

// HTTPAgentClient is responsible for calling node-agent over HTTP.
//
// Requests are sent to http://<host>:<port><path>, where the host is resolved
// per node by nodeURL. Build it with NewHTTPAgentClient so the timeout and
// logger defaults are applied.
type HTTPAgentClient struct {
	client *http.Client // sends every request; NewHTTPAgentClient sets its Timeout
	logger *slog.Logger // receives the error logs emitted when a call fails
	port   int          // node-agent port joined with the node host in nodeURL
}

// nodeAgentCallSpec describes a single node-agent invocation.
//
// path is appended verbatim to the node-agent base URL. operation is the short
// English name used in wrapped errors and AgentHTTPError. buildURLMessage and
// callFailureMessage are the log messages used when building the URL or
// performing the call fails. extraAttrs are extra slog key/value pairs appended
// to those log lines.
type nodeAgentCallSpec struct {
	path, operation                     string
	buildURLMessage, callFailureMessage string
	extraAttrs                          []any
}

// Structured-logging attribute keys attached to every node-agent failure log.
// They are deliberately untyped string constants so slog receives plain strings.
const (
	logAttrNode  = "node"
	logAttrTask  = "taskID"
	logAttrURL   = "url"
	logAttrError = "error"
)

// AgentHTTPError reports a non-success HTTP status from one node-agent call.
//
// Operation is the short name of the call that failed (for example
// "submit warmup"), and StatusCode is the HTTP status the node-agent returned.
type AgentHTTPError struct {
	Operation  string
	StatusCode int
}

// Error formats the failure as "<operation> failed with status <code>".
// It is safe to call on a nil *AgentHTTPError, which yields "".
func (e *AgentHTTPError) Error() string {
	if e == nil {
		return ""
	}
	return fmt.Sprintf("%s failed with status %d", e.Operation, e.StatusCode)
}

// IsAgentHTTPStatus reports whether err, or any error in its wrap chain, is a
// non-nil *AgentHTTPError carrying exactly statusCode.
//
// A typed-nil *AgentHTTPError in the chain counts as "no match". errors.As still
// succeeds for it, and dereferencing the nil result would otherwise panic.
func IsAgentHTTPStatus(err error, statusCode int) bool {
	var target *AgentHTTPError
	if !errors.As(err, &target) || target == nil {
		return false
	}
	return target.StatusCode == statusCode
}

// NewHTTPAgentClient builds the controller-side client used to reach node-agent.
//
// port is the node-agent listening port. A non-positive timeout falls back to
// 10 seconds. A nil logger falls back to slog.Default().
func NewHTTPAgentClient(port int, timeout time.Duration, logger *slog.Logger) *HTTPAgentClient {
	const defaultTimeout = 10 * time.Second

	effectiveTimeout := timeout
	if effectiveTimeout <= 0 {
		effectiveTimeout = defaultTimeout
	}

	effectiveLogger := logger
	if effectiveLogger == nil {
		effectiveLogger = slog.Default()
	}

	httpClient := &http.Client{Timeout: effectiveTimeout}
	return &HTTPAgentClient{
		client: httpClient,
		logger: effectiveLogger,
		port:   port,
	}
}

// SubmitWarmup POSTs one warmup execution plan as JSON to /v1/warmups on the
// node's agent and decodes the reply into a TaskHandle. Failures are logged
// with the plan's task ID before being returned.
func (c *HTTPAgentClient) SubmitWarmup(ctx context.Context, node corev1.Node, req sharedtypes.SubmitWarmupRequest) (sharedtypes.TaskHandle, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/warmups",
		operation:          "submit warmup",
		buildURLMessage:    "构建 node-agent Warmup URL 失败",
		callFailureMessage: "向 node-agent 提交 Warmup 任务失败",
		extraAttrs:         []any{logAttrTask, req.Plan.TaskID},
	}
	return callNodeAgentPostJSON[sharedtypes.SubmitWarmupRequest, sharedtypes.TaskHandle](ctx, c, node, req, spec)
}

// BuildManifest POSTs req as JSON to /v1/manifests:build on the node's agent
// and decodes the logical manifest it returns. Failures are logged with the
// request's artifact key before being returned.
func (c *HTTPAgentClient) BuildManifest(ctx context.Context, node corev1.Node, req sharedtypes.BuildManifestRequest) (sharedtypes.BuildManifestResponse, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/manifests:build",
		operation:          "build manifest",
		buildURLMessage:    "构建 node-agent manifest URL 失败",
		callFailureMessage: "向 node-agent 构建逻辑 manifest 失败",
		extraAttrs:         []any{"artifactKey", req.ArtifactKey},
	}
	return callNodeAgentPostJSON[sharedtypes.BuildManifestRequest, sharedtypes.BuildManifestResponse](ctx, c, node, req, spec)
}

// GetWarmupTaskStatus GETs /v1/warmups/<taskID> from the node's agent and
// decodes the returned TaskStatus. Failures are logged with the task ID before
// being returned.
//
// The task ID is path-escaped before it is appended. Without escaping, an ID
// containing reserved characters ('/', '?', '#', '%', spaces, ...) would address
// a different path or yield a malformed URL. Plain IDs (letters, digits, '-',
// '_', '.') are unchanged by the escaping.
func (c *HTTPAgentClient) GetWarmupTaskStatus(ctx context.Context, node corev1.Node, req sharedtypes.GetWarmupTaskStatusRequest) (sharedtypes.TaskStatus, error) {
	return callNodeAgentGetJSON[sharedtypes.TaskStatus](
		ctx,
		c,
		node,
		nodeAgentCallSpec{
			path:               "/v1/warmups/" + url.PathEscape(req.TaskID),
			operation:          "get warmup status",
			buildURLMessage:    "构建 node-agent 任务状态 URL 失败",
			callFailureMessage: "获取 node-agent 任务状态失败",
			extraAttrs:         []any{logAttrTask, req.TaskID},
		},
	)
}

// OpenCollective POSTs req as JSON to /v1/collectives/open on the node's agent
// to open one collective session and decodes the agent's reply. Failures are
// logged with the task ID before being returned.
func (c *HTTPAgentClient) OpenCollective(ctx context.Context, node corev1.Node, req sharedtypes.OpenCollectiveRequest) (sharedtypes.OpenCollectiveResponse, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/collectives/open",
		operation:          "open collective",
		buildURLMessage:    "构建 node-agent collective open URL 失败",
		callFailureMessage: "打开 node-agent collective 会话失败",
		extraAttrs:         []any{logAttrTask, req.TaskID},
	}
	return callNodeAgentPostJSON[sharedtypes.OpenCollectiveRequest, sharedtypes.OpenCollectiveResponse](ctx, c, node, req, spec)
}

// StepCollective POSTs req as JSON to /v1/collectives/step on the node's agent
// to advance one collective iteration and decodes the agent's reply. Failures
// are logged with the task ID before being returned.
func (c *HTTPAgentClient) StepCollective(ctx context.Context, node corev1.Node, req sharedtypes.CollectiveStepRequest) (sharedtypes.CollectiveStepResponse, error) {
	spec := nodeAgentCallSpec{
		path:               "/v1/collectives/step",
		operation:          "step collective",
		buildURLMessage:    "构建 node-agent collective step URL 失败",
		callFailureMessage: "推进 node-agent collective 失败",
		extraAttrs:         []any{logAttrTask, req.TaskID},
	}
	return callNodeAgentPostJSON[sharedtypes.CollectiveStepRequest, sharedtypes.CollectiveStepResponse](ctx, c, node, req, spec)
}

// CompleteCollective POSTs req as JSON to /v1/collectives/complete on the
// node's agent to close one collective session. Any status below 300 counts as
// success; no response payload is decoded. Failures are logged with the task ID
// before being returned.
func (c *HTTPAgentClient) CompleteCollective(ctx context.Context, node corev1.Node, req sharedtypes.CompleteCollectiveRequest) error {
	spec := nodeAgentCallSpec{
		path:               "/v1/collectives/complete",
		operation:          "complete collective",
		buildURLMessage:    "构建 node-agent collective complete URL 失败",
		callFailureMessage: "关闭 node-agent collective 会话失败",
		extraAttrs:         []any{logAttrTask, req.TaskID},
	}
	return callNodeAgentPostNoContent[sharedtypes.CompleteCollectiveRequest](ctx, c, node, req, spec)
}

// callNodeAgentPostJSON POSTs req as JSON to spec.path on node's agent and
// decodes the JSON reply into Resp.
//
// A URL-build failure is logged with spec.buildURLMessage and returned wrapped
// as "build <operation> URL". A request failure is logged with
// spec.callFailureMessage together with the target URL and returned unchanged.
// spec.extraAttrs are appended to both log lines.
func callNodeAgentPostJSON[Req, Resp any](
	ctx context.Context,
	client *HTTPAgentClient,
	node corev1.Node,
	req Req,
	spec nodeAgentCallSpec,
) (Resp, error) {
	target, err := client.nodeURL(&node, spec.path)
	if err != nil {
		attrs := append([]any{logAttrNode, node.Name, logAttrError, err}, spec.extraAttrs...)
		client.logger.Error(spec.buildURLMessage, attrs...)
		var zero Resp
		return zero, errutil.Wrap("build "+spec.operation+" URL", err)
	}

	response, err := postJSON[Req, Resp](ctx, client.client, target, req, spec.operation)
	if err != nil {
		attrs := append([]any{logAttrNode, node.Name, logAttrURL, target, logAttrError, err}, spec.extraAttrs...)
		client.logger.Error(spec.callFailureMessage, attrs...)
	}
	return response, err
}

// callNodeAgentGetJSON GETs spec.path from node's agent and decodes the JSON
// reply into Resp.
//
// A URL-build failure is logged with spec.buildURLMessage and returned wrapped
// as "build <operation> URL". A request failure is logged with
// spec.callFailureMessage together with the target URL and returned unchanged.
// spec.extraAttrs are appended to both log lines.
func callNodeAgentGetJSON[Resp any](
	ctx context.Context,
	client *HTTPAgentClient,
	node corev1.Node,
	spec nodeAgentCallSpec,
) (Resp, error) {
	target, err := client.nodeURL(&node, spec.path)
	if err != nil {
		attrs := append([]any{logAttrNode, node.Name, logAttrError, err}, spec.extraAttrs...)
		client.logger.Error(spec.buildURLMessage, attrs...)
		var zero Resp
		return zero, errutil.Wrap("build "+spec.operation+" URL", err)
	}

	response, err := getJSON[Resp](ctx, client.client, target, spec.operation)
	if err != nil {
		attrs := append([]any{logAttrNode, node.Name, logAttrURL, target, logAttrError, err}, spec.extraAttrs...)
		client.logger.Error(spec.callFailureMessage, attrs...)
	}
	return response, err
}

// callNodeAgentPostNoContent POSTs req as JSON to spec.path on node's agent
// and expects no response payload.
//
// A URL-build failure is logged with spec.buildURLMessage and returned wrapped
// as "build <operation> URL". A request failure is logged with
// spec.callFailureMessage together with the target URL and returned unchanged.
// spec.extraAttrs are appended to both log lines.
func callNodeAgentPostNoContent[Req any](
	ctx context.Context,
	client *HTTPAgentClient,
	node corev1.Node,
	req Req,
	spec nodeAgentCallSpec,
) error {
	target, err := client.nodeURL(&node, spec.path)
	if err != nil {
		attrs := append([]any{logAttrNode, node.Name, logAttrError, err}, spec.extraAttrs...)
		client.logger.Error(spec.buildURLMessage, attrs...)
		return errutil.Wrap("build "+spec.operation+" URL", err)
	}

	if err = postNoContent(ctx, client.client, target, req, spec.operation); err != nil {
		attrs := append([]any{logAttrNode, node.Name, logAttrURL, target, logAttrError, err}, spec.extraAttrs...)
		client.logger.Error(spec.callFailureMessage, attrs...)
		return err
	}
	return nil
}

// postJSON marshals req to JSON, POSTs it to url with Content-Type
// application/json, and hands the response to decodeJSONResponse.
//
// Marshal, request-build and transport failures are wrapped as
// "marshal/build/send <operation> request" respectively.
func postJSON[Req, Resp any](ctx context.Context, client *http.Client, url string, req Req, operation string) (Resp, error) {
	var zero Resp

	payload, err := json.Marshal(req)
	if err != nil {
		return zero, errutil.Wrap("marshal "+operation+" request", err)
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
	if err != nil {
		return zero, errutil.Wrap("build "+operation+" request", err)
	}
	request.Header.Set("Content-Type", "application/json")

	response, err := client.Do(request)
	if err != nil {
		return zero, errutil.Wrap("send "+operation+" request", err)
	}
	return decodeJSONResponse[Resp](response, operation)
}

// getJSON issues a bodiless GET to url and hands the response to
// decodeJSONResponse. Request-build and transport failures are wrapped as
// "build/send <operation> request".
func getJSON[Resp any](ctx context.Context, client *http.Client, url, operation string) (Resp, error) {
	var zero Resp

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
	if err != nil {
		return zero, errutil.Wrap("build "+operation+" request", err)
	}

	response, err := client.Do(request)
	if err != nil {
		return zero, errutil.Wrap("send "+operation+" request", err)
	}
	return decodeJSONResponse[Resp](response, operation)
}

// decodeJSONResponse turns one node-agent response into a Resp value and
// always closes resp.Body.
//
// Statuses >= 300 become *AgentHTTPError. Otherwise the body is decoded as JSON
// by decodeJSONResponseBody. Before Close, any unread remainder of the body is
// discarded (up to a bound), because net/http may not reuse a keep-alive
// connection whose body was not read to EOF. Error-status bodies are otherwise
// never read at all. A close error is reported only when decoding succeeded, so
// the more specific failure wins.
func decodeJSONResponse[Resp any](resp *http.Response, operation string) (Resp, error) {
	// maxDrainBytes caps how much leftover body is read purely to recycle the
	// connection; a larger remainder is cheaper to abandon with the connection.
	const maxDrainBytes = 64 << 10

	var output Resp
	resultErr := decodeJSONResponseBody(resp, operation, &output)

	// Best effort: a failed drain only costs connection reuse, and Close below
	// still releases the body, so the drain result is deliberately ignored.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
	closeErr := resp.Body.Close()
	if resultErr == nil && closeErr != nil {
		resultErr = errutil.Wrap("close "+operation+" response body", closeErr)
	}
	if resultErr != nil {
		var zero Resp
		return zero, resultErr
	}
	return output, nil
}

// decodeJSONResponseBody checks resp's status. For statuses below 300 it
// decodes one JSON value from resp.Body into output. Statuses >= 300 are
// reported as *AgentHTTPError without reading the body. It never closes the
// body; that is the caller's job.
func decodeJSONResponseBody[Resp any](resp *http.Response, operation string, output *Resp) error {
	if code := resp.StatusCode; code >= 300 {
		return &AgentHTTPError{Operation: operation, StatusCode: code}
	}

	decoder := json.NewDecoder(resp.Body)
	decodeErr := decoder.Decode(output)
	if decodeErr == nil {
		return nil
	}
	return errutil.Wrap("decode "+operation+" response", decodeErr)
}

// postNoContent marshals req to JSON, POSTs it to url with Content-Type
// application/json, and treats any status below 300 as success without
// decoding a payload.
//
// Marshal, request-build and transport failures are wrapped as
// "marshal/build/send <operation> request". Once a response arrives, its body
// is drained (up to a bound) and closed on every path. net/http may not reuse a
// keep-alive connection whose body was closed before EOF, and this function
// never reads the body otherwise. A close error is reported only when the
// status check passed.
func postNoContent[Req any](ctx context.Context, client *http.Client, url string, req Req, operation string) error {
	// maxDrainBytes caps how much leftover body is read purely to recycle the
	// connection; a larger remainder is cheaper to abandon with the connection.
	const maxDrainBytes = 64 << 10

	body, err := json.Marshal(req)
	if err != nil {
		return errutil.Wrap("marshal "+operation+" request", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return errutil.Wrap("build "+operation+" request", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := client.Do(httpReq)
	if err != nil {
		return errutil.Wrap("send "+operation+" request", err)
	}

	resultErr := noContentResponseError(resp, operation)
	// Best effort: a failed drain only costs connection reuse, and Close below
	// still releases the body, so the drain result is deliberately ignored.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
	closeErr := resp.Body.Close()
	if resultErr == nil && closeErr != nil {
		return errutil.Wrap("close "+operation+" response body", closeErr)
	}
	return resultErr
}

// noContentResponseError maps a response status to an error: an
// *AgentHTTPError naming operation and the status for statuses >= 300, and a
// literal nil otherwise. The body is not touched.
func noContentResponseError(resp *http.Response, operation string) error {
	if resp.StatusCode >= 300 {
		return &AgentHTTPError{Operation: operation, StatusCode: resp.StatusCode}
	}
	return nil
}

// nodeURL builds "http://<host>:<port><path>" for node, where host comes from
// nodeURLHost and port is the client's configured node-agent port.
// net.JoinHostPort brackets IPv6 hosts. path is appended verbatim.
func (c *HTTPAgentClient) nodeURL(node *corev1.Node, path string) (string, error) {
	host, err := nodeURLHost(node)
	if err != nil {
		return "", err
	}
	authority := net.JoinHostPort(host, fmt.Sprintf("%d", c.port))
	return "http://" + authority + path, nil
}

// nodeURLHost returns the host used to reach node's agent, as resolved by
// nodemeta.ExtractNodeInternalIP. Both results are passed through unchanged.
func nodeURLHost(node *corev1.Node) (string, error) {
	host, err := nodemeta.ExtractNodeInternalIP(node)
	return host, err
}