/*
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * openFuyao is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package rdma

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
	sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)

// waitForPeerFanoutSourceHalfReady polls the peer's staging area until the
// marker produced by fanoutStagingReadyMarkerRelativePath reports a size
// greater than zero.
//
// The relative path is the first non-empty RelativePath in ranges, falling
// back to fanoutSubTaskRelativePath(spec). The function returns nil without
// polling when spec is not a directory per-file fan-out sub-task, when the
// peer has no endpoint or staging path, or when no relative path can be
// determined.
//
// It returns an error when the wait budget from fanoutPeerWaitTimeout is
// exhausted (wrapping the last Stat error, if any) or when waitForNextPoll
// fails, e.g. because ctx is done.
func waitForPeerFanoutSourceHalfReady(
	ctx context.Context,
	client ChunkClient,
	spec sharedtypes.TransferSpec,
	peer sharedtypes.CollectivePeerPlan,
	ranges []sharedtypes.ByteRange,
) error {
	switch {
	case !isDirectoryPerFileFanoutSubTask(spec):
		return nil
	case peer.Endpoint == "", peer.StagingPath == "":
		return nil
	}

	// Prefer the path carried by the ranges; fall back to the sub-task's own path.
	var target string
	for i := range ranges {
		if candidate := ranges[i].RelativePath; candidate != "" {
			target = candidate
			break
		}
	}
	if target == "" {
		if target = fanoutSubTaskRelativePath(spec); target == "" {
			return nil
		}
	}

	marker := fanoutStagingReadyMarkerRelativePath(target)
	budget := fanoutPeerWaitTimeout(spec, ranges)
	deadline := time.Now().Add(budget)
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	for {
		size, statErr := client.Stat(ctx, peer.Endpoint, peer.StagingPath, marker)
		if statErr == nil && size > 0 {
			return nil
		}
		if time.Now().After(deadline) {
			if statErr == nil {
				return fmt.Errorf("timeout waiting for peer source-half ready marker %s on %s", marker, peer.NodeName)
			}
			return fmt.Errorf("timeout waiting for peer source-half ready marker %s on %s: %w", marker, peer.NodeName, statErr)
		}
		if pollErr := waitForNextPoll(ctx, ticker); pollErr != nil {
			return pollErr
		}
	}
}

// waitForPeerFanoutFileDone polls the peer's staging area until the marker
// produced by fanoutStagingDoneMarkerRelativePath(relativePath) reports a size
// greater than zero.
//
// It returns nil without polling when spec is not a directory per-file
// fan-out sub-task or when the peer endpoint, peer staging path, or
// relativePath is empty. It returns an error when the wait budget is
// exhausted (wrapping the last Stat error, if any) or when waitForNextPoll
// fails.
func waitForPeerFanoutFileDone(
	ctx context.Context,
	client ChunkClient,
	spec sharedtypes.TransferSpec,
	peer sharedtypes.CollectivePeerPlan,
	relativePath string,
) error {
	switch {
	case !isDirectoryPerFileFanoutSubTask(spec):
		return nil
	case peer.Endpoint == "", peer.StagingPath == "", relativePath == "":
		return nil
	}

	marker := fanoutStagingDoneMarkerRelativePath(relativePath)
	// A one-byte placeholder range is handed to fanoutPeerWaitTimeout so the
	// budget is derived from a non-empty range list.
	placeholder := sharedtypes.ByteRange{RelativePath: relativePath, Start: 0, End: 1}
	budget := fanoutPeerWaitTimeout(spec, []sharedtypes.ByteRange{placeholder})
	deadline := time.Now().Add(budget)
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	for {
		size, statErr := client.Stat(ctx, peer.Endpoint, peer.StagingPath, marker)
		if statErr == nil && size > 0 {
			return nil
		}
		if time.Now().After(deadline) {
			if statErr == nil {
				return fmt.Errorf("timeout waiting for peer file-done marker %s on %s", marker, peer.NodeName)
			}
			return fmt.Errorf("timeout waiting for peer file-done marker %s on %s: %w", marker, peer.NodeName, statErr)
		}
		if pollErr := waitForNextPoll(ctx, ticker); pollErr != nil {
			return pollErr
		}
	}
}

// waitForPreviousDirectoryFanoutFile makes ring rank 0 wait until the peer
// with rank WorldSize-1 has published the file-done marker for the previous
// file.
//
// It is a no-op (returns nil) unless the transfer mode is
// TransferModePartialPullAllGather, a ring is configured, the local rank is 0,
// previous has a relative path, and a peer with rank WorldSize-1 exists.
func waitForPreviousDirectoryFanoutFile(ctx context.Context, client ChunkClient, spec sharedtypes.TransferSpec, previous sharedtypes.ArtifactFile) error {
	if spec.TransferMode != sharedtypes.TransferModePartialPullAllGather {
		return nil
	}
	ring := spec.CollectiveSpec.Ring
	if ring == nil {
		return nil
	}
	if ring.Rank != 0 || previous.RelativePath == "" {
		return nil
	}

	tailRank := ring.WorldSize - 1
	for _, candidate := range spec.CollectiveSpec.Peers {
		if candidate.Rank != tailRank {
			continue
		}
		// Only the first peer with the tail rank is consulted.
		return waitForPeerFanoutFileDone(ctx, client, spec, candidate, previous.RelativePath)
	}
	return nil
}

// waitForNPeerSourceTurn blocks until the peer whose rank is one below the
// local peer's rank has published its source-half ready marker for the local
// peer's owned ranges.
//
// It returns nil immediately when spec is not a directory per-file fan-out
// sub-task or when the local rank is zero or negative. It returns an error
// when no peer with the predecessor rank is known to np.
func waitForNPeerSourceTurn(
	ctx context.Context,
	client ChunkClient,
	spec sharedtypes.TransferSpec,
	np nPeerFanoutSpec,
) error {
	if !isDirectoryPerFileFanoutSubTask(spec) {
		return nil
	}
	if np.selfPeer.Rank <= 0 {
		// There is no lower rank to wait on.
		return nil
	}
	predecessorRank := np.selfPeer.Rank - 1
	predecessor, known := np.peerByRank(predecessorRank)
	if !known {
		return fmt.Errorf("n-peer source turn missing predecessor rank %d", predecessorRank)
	}
	return waitForPeerFanoutSourceHalfReady(ctx, client, spec, predecessor, np.selfPeer.OwnedRanges)
}

// waitForPeerRelayFileCoverage waits until the peer's staged files reach the
// minimum sizes computed by requiredRelayFileSizes for payloads.
//
// It returns nil without waiting when the peer has no endpoint or staging
// path, or when no file sizes are required.
func (a *Adapter) waitForPeerRelayFileCoverage(ctx context.Context, spec sharedtypes.TransferSpec, peer sharedtypes.CollectivePeerPlan, payloads []sharedtypes.CollectiveChunkPayload) error {
	hasStaging := peer.Endpoint != "" && peer.StagingPath != ""
	if !hasStaging {
		return nil
	}
	minSizes := requiredRelayFileSizes(payloads)
	if len(minSizes) == 0 {
		return nil
	}
	budget := fanoutPeerWaitTimeout(spec, payloadsToByteRanges(payloads))
	return a.waitForPeerRequiredFileSizes(ctx, peer, minSizes, budget, "relay file")
}

// waitForPeerRangeCoverage waits first for the peer's source-half ready marker
// and then for the peer's staged files to reach the minimum sizes computed by
// requiredRangeFileSizes for ranges.
//
// It returns nil without waiting when the peer has no endpoint or staging
// path. An error from the marker wait is returned before any file size is
// checked.
func (a *Adapter) waitForPeerRangeCoverage(ctx context.Context, spec sharedtypes.TransferSpec, peer sharedtypes.CollectivePeerPlan, ranges []sharedtypes.ByteRange) error {
	hasStaging := peer.Endpoint != "" && peer.StagingPath != ""
	if !hasStaging {
		return nil
	}
	markerErr := waitForPeerFanoutSourceHalfReady(ctx, a.client, spec, peer, ranges)
	if markerErr != nil {
		return markerErr
	}
	minSizes := requiredRangeFileSizes(ranges)
	if len(minSizes) == 0 {
		return nil
	}
	budget := fanoutPeerWaitTimeout(spec, ranges)
	return a.waitForPeerRequiredFileSizes(ctx, peer, minSizes, budget, "file")
}

// waitForPeerRequiredFileSizes waits for every file in required (relative path
// mapped to minimum size) to reach its minimum size on the peer.
//
// All files share one deadline (now + timeout) and one 50ms ticker, so timeout
// bounds the whole set rather than each file. fileKind is only used in error
// messages. The first failing wait aborts the loop and its error is returned.
func (a *Adapter) waitForPeerRequiredFileSizes(
	ctx context.Context,
	peer sharedtypes.CollectivePeerPlan,
	required map[string]int64,
	timeout time.Duration,
	fileKind string,
) error {
	deadline := time.Now().Add(timeout)
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	// Fields shared by every per-file wait; path and size are filled in per entry.
	template := peerFileWait{
		peer:     peer,
		deadline: deadline,
		ticker:   ticker,
		fileKind: fileKind,
	}
	for path, size := range required {
		current := template
		current.relativePath = path
		current.minSize = size
		if err := a.waitForPeerFileSize(ctx, current); err != nil {
			return err
		}
	}
	return nil
}

// peerFileWait bundles the parameters of a single waitForPeerFileSize call.
type peerFileWait struct {
	// peer is the peer whose Endpoint and StagingPath are queried; its
	// NodeName appears in timeout errors.
	peer sharedtypes.CollectivePeerPlan
	// relativePath is the file path passed to Stat, relative to the staging path.
	relativePath string
	// minSize is the size in bytes the file must reach (size >= minSize).
	minSize int64
	// deadline is the instant after which the wait gives up with a timeout error.
	deadline time.Time
	// ticker paces the polling; it is created and stopped by the caller.
	ticker *time.Ticker
	// fileKind is a label used only in timeout error messages.
	fileKind string
}

// waitForPeerFileSize polls Stat on the peer until wait.relativePath reports at
// least wait.minSize bytes.
//
// Once wait.deadline has passed it returns a timeout error: wrapping the last
// Stat error when there was one, otherwise reporting the required and current
// sizes. An error from waitForNextPoll is returned unchanged.
func (a *Adapter) waitForPeerFileSize(ctx context.Context, wait peerFileWait) error {
	target := wait.peer
	for {
		current, statErr := a.client.Stat(ctx, target.Endpoint, target.StagingPath, wait.relativePath)
		switch {
		case statErr == nil && current >= wait.minSize:
			return nil
		case !time.Now().After(wait.deadline):
			// Budget remains: sleep until the next tick, then stat again.
			if pollErr := waitForNextPoll(ctx, wait.ticker); pollErr != nil {
				return pollErr
			}
		case statErr != nil:
			return fmt.Errorf("timeout waiting for peer %s %s on %s: %w", wait.fileKind, wait.relativePath, target.NodeName, statErr)
		default:
			return fmt.Errorf("timeout waiting for peer %s %s on %s to reach %d bytes, current=%d", wait.fileKind, wait.relativePath, target.NodeName, wait.minSize, current)
		}
	}
}

// fanoutPeerWaitTimeout returns how long to wait on a peer for the given
// ranges.
//
// For anything other than a directory per-file fan-out sub-task, or when the
// ranges contain no bytes, the result is collectiveTimeout(spec). Otherwise a
// size-based estimate is computed as ceil(totalBytes / 256 MiB) seconds plus
// 45 seconds; the result is the collective timeout when the estimate is
// smaller than it, and otherwise the estimate capped at 15 minutes.
func fanoutPeerWaitTimeout(spec sharedtypes.TransferSpec, ranges []sharedtypes.ByteRange) time.Duration {
	const (
		assumedBytesPerSecond = int64(256 * 1024 * 1024)
		fixedSlack            = 45 * time.Second
		ceiling               = 15 * time.Minute
	)

	base := collectiveTimeout(spec)
	if !isDirectoryPerFileFanoutSubTask(spec) {
		return base
	}

	var payloadBytes int64
	for i := range ranges {
		// Ranges whose End is not past Start contribute nothing.
		if ranges[i].End > ranges[i].Start {
			payloadBytes += ranges[i].End - ranges[i].Start
		}
	}
	if payloadBytes <= 0 {
		return base
	}

	// Round up to whole seconds at the assumed throughput.
	seconds := (payloadBytes + assumedBytesPerSecond - 1) / assumedBytesPerSecond
	estimate := time.Duration(seconds)*time.Second + fixedSlack
	switch {
	case estimate < base:
		return base
	case estimate > ceiling:
		return ceiling
	default:
		return estimate
	}
}

// waitCollectivePayloads polls the previous ring member's chunk listing every
// 100ms until it reports at least as many chunks as there are valid ranges in
// expectedRanges (per countValidRanges), and returns those chunks.
//
// It returns an error when spec carries no ring, when the previous endpoint
// cannot be normalized, when a listing request fails, when timeout elapses
// before enough chunks are listed, or when waitForNextPoll fails (e.g. ctx is
// done). Every error is logged before it is returned.
func (a *Adapter) waitCollectivePayloads(ctx context.Context, spec sharedtypes.TransferSpec, sessionID string, iteration int32, expectedRanges []sharedtypes.ByteRange, timeout time.Duration) ([]sharedtypes.CollectiveChunkPayload, error) {
	// Ring is a pointer and may be nil; report that as an error rather than
	// panicking on the PrevEndpoint dereference below.
	if spec.CollectiveSpec.Ring == nil {
		err := fmt.Errorf("wait collective payloads for iteration %d: collective ring spec is nil", iteration)
		a.logger.Error("wait collective payloads missing ring spec", "taskID", spec.TaskID, "iteration", iteration, "err", err)
		return nil, err
	}
	expectedCount := countValidRanges(expectedRanges)
	deadline := time.Now().Add(timeout)
	prevEndpoint, err := normalizeAgentEndpoint(spec.CollectiveSpec.Ring.PrevEndpoint)
	if err != nil {
		a.logger.Error("normalize previous collective endpoint failed", "taskID", spec.TaskID, "iteration", iteration, "err", err)
		return nil, errutil.Wrap("normalize previous collective endpoint", err)
	}
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		output, err := a.listCollectiveChunks(ctx, prevEndpoint, collectiveListRequest(spec.TaskID, sessionID, iteration), "collective chunk list")
		if err != nil {
			a.logger.Error("list collective payloads failed", "taskID", spec.TaskID, "iteration", iteration, "expectedCount", expectedCount, "err", err)
			return nil, err
		}
		if len(output.Chunks) >= expectedCount {
			return output.Chunks, nil
		}
		// The deadline is checked only after a listing, so at least one
		// listing is always attempted even with a zero timeout.
		if time.Now().After(deadline) {
			err := fmt.Errorf("timed out waiting for collective iteration %d: expected at least %d chunks, got %d", iteration, expectedCount, len(output.Chunks))
			a.logger.Error("wait collective payloads timed out", "taskID", spec.TaskID, "iteration", iteration, "expectedCount", expectedCount, "chunkCount", len(output.Chunks), "err", err)
			return nil, err
		}
		if err := waitForNextPoll(ctx, ticker); err != nil {
			a.logger.Error("wait collective payload poll canceled", "taskID", spec.TaskID, "iteration", iteration, "err", err)
			return nil, err
		}
	}
}

// waitForNextPoll blocks until the ticker fires or ctx is done, whichever
// happens first.
//
// It returns nil on a tick, the wrapped ctx.Err() when ctx is done, and an
// error when ctx is nil.
func waitForNextPoll(ctx context.Context, ticker *time.Ticker) error {
	if ctx == nil {
		return fmt.Errorf("wait for next poll requires non-nil context")
	}
	done := ctx.Done()
	select {
	case <-ticker.C:
		return nil
	case <-done:
		return errutil.Wrap("wait for next poll", ctx.Err())
	}
}

// waitCollectivePayloadMetadata polls endpoint's chunk listing every 50ms until
// handleRelayCollectiveMetadataPoll reports the listing as complete, and
// returns the listed chunks.
//
// An empty listing is retried until timeout elapses; a non-empty listing is
// handed to handleRelayCollectiveMetadataPoll, which either completes the
// wait, fails it, or waits one tick before the next listing. Errors from
// endpoint normalization, listing, timeout and polling are logged and
// returned.
func (a *Adapter) waitCollectivePayloadMetadata(ctx context.Context, endpoint, taskID, sessionID string, iteration int32, timeout time.Duration) ([]sharedtypes.CollectiveChunkPayload, error) {
	deadline := time.Now().Add(timeout)
	target, normErr := normalizeAgentEndpoint(endpoint)
	if normErr != nil {
		a.logger.Error("normalize relay collective endpoint failed", "taskID", taskID, "iteration", iteration, "endpoint", endpoint, "err", normErr)
		return nil, errutil.Wrap("normalize relay collective endpoint", normErr)
	}
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	for {
		listing, listErr := a.listCollectiveChunks(ctx, target, collectiveListRequest(taskID, sessionID, iteration), "relay collective chunk list")
		if listErr != nil {
			a.logger.Error("list relay collective payload metadata failed", "taskID", taskID, "iteration", iteration, "endpoint", target, "err", listErr)
			return nil, listErr
		}

		if len(listing.Chunks) == 0 {
			// Nothing published yet: give up on timeout, otherwise wait a tick.
			if time.Now().After(deadline) {
				timeoutErr := fmt.Errorf("timed out waiting for relay collective iteration %d", iteration)
				a.logger.Error("wait relay collective payload metadata timed out", "taskID", taskID, "iteration", iteration, "endpoint", target, "err", timeoutErr)
				return nil, timeoutErr
			}
			if pollErr := waitForNextPoll(ctx, ticker); pollErr != nil {
				a.logger.Error("wait relay collective payload metadata poll canceled", "taskID", taskID, "iteration", iteration, "endpoint", target, "err", pollErr)
				return nil, pollErr
			}
			continue
		}

		chunks, done, handleErr := a.handleRelayCollectiveMetadataPoll(ctx, taskID, iteration, deadline, listing, ticker)
		if handleErr != nil {
			return nil, handleErr
		}
		if done {
			return chunks, nil
		}
	}
}

// collectiveListRequest builds the list-chunks request for the given task,
// session and iteration.
func collectiveListRequest(taskID, sessionID string, iteration int32) sharedtypes.ListCollectiveChunksRequest {
	var request sharedtypes.ListCollectiveChunksRequest
	request.TaskID = taskID
	request.SessionID = sessionID
	request.Iteration = iteration
	return request
}

// listCollectiveChunks POSTs request as JSON to
// endpoint+"/v1/collectives/chunks/list" and decodes the JSON response.
//
// action is a label used only to build error messages. A response status of
// 300 or above is reported as an error and the body is not decoded. The
// response body is closed on every path after a successful round trip.
//
// On any error the returned response is the zero value.
func (a *Adapter) listCollectiveChunks(
	ctx context.Context,
	endpoint string,
	request sharedtypes.ListCollectiveChunksRequest,
	action string,
) (sharedtypes.ListCollectiveChunksResponse, error) {
	var empty sharedtypes.ListCollectiveChunksResponse

	body, err := json.Marshal(request)
	if err != nil {
		return empty, errutil.Wrap(fmt.Sprintf("marshal %s request", action), err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/v1/collectives/chunks/list", bytes.NewReader(body))
	if err != nil {
		return empty, errutil.Wrap(fmt.Sprintf("build %s request", action), err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := a.httpClient.Do(req)
	if err != nil {
		return empty, errutil.Wrap(fmt.Sprintf("post %s request", action), err)
	}

	if resp.StatusCode >= 300 {
		statusErr := fmt.Errorf("list collective chunks failed with status %d", resp.StatusCode)
		if closeErr := resp.Body.Close(); closeErr != nil {
			// Keep the HTTP status as the primary error; a failed close must
			// not hide why the request was rejected.
			return empty, fmt.Errorf("%w (close %s response body: %v)", statusErr, action, closeErr)
		}
		return empty, statusErr
	}

	var output sharedtypes.ListCollectiveChunksResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&output)
	closeErr := resp.Body.Close()
	if decodeErr != nil {
		return empty, errutil.Wrap(fmt.Sprintf("decode %s response", action), decodeErr)
	}
	if closeErr != nil {
		// Reported on its own so a close failure is not labeled as a decode failure.
		return empty, errutil.Wrap(fmt.Sprintf("close %s response body", action), closeErr)
	}
	return output, nil
}

// handleRelayCollectiveMetadataPoll decides what to do with one non-empty relay
// chunk listing.
//
// It returns (output.Chunks, true, nil) when the listing declares no expected
// count (ExpectedChunks <= 0) or already holds at least ExpectedChunks chunks.
// Otherwise it returns a timeout error once deadline has passed, or waits for
// the next tick and returns (nil, false, nil) so the caller lists again. An
// error from waitForNextPoll is logged and returned.
func (a *Adapter) handleRelayCollectiveMetadataPoll(
	ctx context.Context,
	taskID string,
	iteration int32,
	deadline time.Time,
	output sharedtypes.ListCollectiveChunksResponse,
	ticker *time.Ticker,
) ([]sharedtypes.CollectiveChunkPayload, bool, error) {
	have := len(output.Chunks)
	want := output.ExpectedChunks
	if want <= 0 || have >= int(want) {
		return output.Chunks, true, nil
	}

	if time.Now().After(deadline) {
		timeoutErr := fmt.Errorf(
			"timed out waiting for complete relay collective iteration %d: have %d chunks, expect %d",
			iteration,
			have,
			want,
		)
		a.logger.Error("wait relay collective payload metadata incomplete", "taskID", taskID, "iteration", iteration, "chunkCount", have, "expectedChunks", want, "err", timeoutErr)
		return nil, false, timeoutErr
	}

	pollErr := waitForNextPoll(ctx, ticker)
	if pollErr != nil {
		a.logger.Error("wait relay collective payload metadata retry canceled", "taskID", taskID, "iteration", iteration, "err", pollErr)
		return nil, false, pollErr
	}
	return nil, false, nil
}