* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package rdma
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)
// waitForPeerFanoutSourceHalfReady polls the peer's staging area until the
// source-half "ready" marker for the relevant file reports a positive size.
//
// The call is a no-op (returns nil) when spec is not a directory per-file
// fan-out sub task, when the peer has no endpoint or staging path, or when no
// relative path can be determined. The relative path is the first non-empty
// RelativePath among ranges, falling back to fanoutSubTaskRelativePath(spec).
//
// The wait budget comes from fanoutPeerWaitTimeout(spec, ranges) and the
// marker is probed every 50ms. On expiry a timeout error is returned; it wraps
// the last Stat error when there was one. Context cancellation surfaces
// through waitForNextPoll.
func waitForPeerFanoutSourceHalfReady(
	ctx context.Context,
	client ChunkClient,
	spec sharedtypes.TransferSpec,
	peer sharedtypes.CollectivePeerPlan,
	ranges []sharedtypes.ByteRange,
) error {
	switch {
	case !isDirectoryPerFileFanoutSubTask(spec):
		return nil
	case peer.Endpoint == "", peer.StagingPath == "":
		return nil
	}

	// Prefer the path carried by the ranges; only consult the spec when none
	// of the ranges names a file.
	var target string
	for _, candidate := range ranges {
		if candidate.RelativePath != "" {
			target = candidate.RelativePath
			break
		}
	}
	if target == "" {
		if target = fanoutSubTaskRelativePath(spec); target == "" {
			return nil
		}
	}

	marker := fanoutStagingReadyMarkerRelativePath(target)
	budget := fanoutPeerWaitTimeout(spec, ranges)
	expiry := time.Now().Add(budget)

	const pollInterval = 50 * time.Millisecond
	tick := time.NewTicker(pollInterval)
	defer tick.Stop()

	for {
		markerSize, statErr := client.Stat(ctx, peer.Endpoint, peer.StagingPath, marker)
		if statErr == nil && markerSize > 0 {
			return nil
		}
		if time.Now().After(expiry) {
			if statErr == nil {
				return fmt.Errorf("timeout waiting for peer source-half ready marker %s on %s", marker, peer.NodeName)
			}
			return fmt.Errorf("timeout waiting for peer source-half ready marker %s on %s: %w", marker, peer.NodeName, statErr)
		}
		if waitErr := waitForNextPoll(ctx, tick); waitErr != nil {
			return waitErr
		}
	}
}
// waitForPeerFanoutFileDone polls the peer's staging area until the
// "file done" marker for relativePath reports a positive size.
//
// The call is a no-op (returns nil) when spec is not a directory per-file
// fan-out sub task, or when the peer endpoint, the peer staging path, or
// relativePath is empty.
//
// The wait budget is fanoutPeerWaitTimeout evaluated for a single one-byte
// range on relativePath, and the marker is probed every 50ms. On expiry a
// timeout error is returned; it wraps the last Stat error when there was one.
func waitForPeerFanoutFileDone(
	ctx context.Context,
	client ChunkClient,
	spec sharedtypes.TransferSpec,
	peer sharedtypes.CollectivePeerPlan,
	relativePath string,
) error {
	switch {
	case !isDirectoryPerFileFanoutSubTask(spec):
		return nil
	case peer.Endpoint == "", peer.StagingPath == "", relativePath == "":
		return nil
	}

	marker := fanoutStagingDoneMarkerRelativePath(relativePath)

	// A one-byte placeholder range is what the timeout derivation is fed here.
	probe := []sharedtypes.ByteRange{
		{RelativePath: relativePath, Start: 0, End: 1},
	}
	budget := fanoutPeerWaitTimeout(spec, probe)
	expiry := time.Now().Add(budget)

	const pollInterval = 50 * time.Millisecond
	tick := time.NewTicker(pollInterval)
	defer tick.Stop()

	for {
		markerSize, statErr := client.Stat(ctx, peer.Endpoint, peer.StagingPath, marker)
		if statErr == nil && markerSize > 0 {
			return nil
		}
		if time.Now().After(expiry) {
			if statErr == nil {
				return fmt.Errorf("timeout waiting for peer file-done marker %s on %s", marker, peer.NodeName)
			}
			return fmt.Errorf("timeout waiting for peer file-done marker %s on %s: %w", marker, peer.NodeName, statErr)
		}
		if waitErr := waitForNextPoll(ctx, tick); waitErr != nil {
			return waitErr
		}
	}
}
// waitForPreviousDirectoryFanoutFile makes ring rank 0 wait until the peer
// holding the last rank (WorldSize-1) has published the file-done marker for
// the previous file.
//
// It returns nil without waiting when the transfer mode is not
// TransferModePartialPullAllGather, when no ring is configured, when the local
// rank is not 0, when previous has no relative path, or when no peer with the
// last rank is listed in the collective spec.
func waitForPreviousDirectoryFanoutFile(ctx context.Context, client ChunkClient, spec sharedtypes.TransferSpec, previous sharedtypes.ArtifactFile) error {
	if spec.TransferMode != sharedtypes.TransferModePartialPullAllGather {
		return nil
	}
	ring := spec.CollectiveSpec.Ring
	if ring == nil {
		return nil
	}
	if ring.Rank != 0 || previous.RelativePath == "" {
		return nil
	}

	tailRank := ring.WorldSize - 1
	for _, candidate := range spec.CollectiveSpec.Peers {
		if candidate.Rank == tailRank {
			// The first peer carrying the last rank is the one waited on.
			return waitForPeerFanoutFileDone(ctx, client, spec, candidate, previous.RelativePath)
		}
	}
	return nil
}
// waitForNPeerSourceTurn blocks until the peer one rank below the local peer
// has published its source-half ready marker.
//
// It returns nil immediately when spec is not a directory per-file fan-out
// sub task or when the local rank is 0 or negative (there is no predecessor).
// An error is returned when the predecessor rank cannot be resolved through
// np.peerByRank.
func waitForNPeerSourceTurn(
	ctx context.Context,
	client ChunkClient,
	spec sharedtypes.TransferSpec,
	np nPeerFanoutSpec,
) error {
	if !isDirectoryPerFileFanoutSubTask(spec) {
		return nil
	}

	ownRank := np.selfPeer.Rank
	if ownRank <= 0 {
		return nil
	}

	predecessorRank := ownRank - 1
	predecessor, found := np.peerByRank(predecessorRank)
	if !found {
		return fmt.Errorf("n-peer source turn missing predecessor rank %d", predecessorRank)
	}

	// The wait is keyed on the ranges owned by the local peer.
	return waitForPeerFanoutSourceHalfReady(ctx, client, spec, predecessor, np.selfPeer.OwnedRanges)
}
// waitForPeerRelayFileCoverage waits until the peer's staged relay files have
// reached the sizes computed by requiredRelayFileSizes for payloads.
//
// It returns nil without waiting when the peer has no endpoint or staging
// path, or when payloads yield no size requirements. The wait budget is
// fanoutPeerWaitTimeout evaluated over the byte ranges derived from payloads.
func (a *Adapter) waitForPeerRelayFileCoverage(ctx context.Context, spec sharedtypes.TransferSpec, peer sharedtypes.CollectivePeerPlan, payloads []sharedtypes.CollectiveChunkPayload) error {
	if peer.Endpoint == "" {
		return nil
	}
	if peer.StagingPath == "" {
		return nil
	}

	minSizes := requiredRelayFileSizes(payloads)
	if len(minSizes) == 0 {
		return nil
	}

	budget := fanoutPeerWaitTimeout(spec, payloadsToByteRanges(payloads))
	return a.waitForPeerRequiredFileSizes(ctx, peer, minSizes, budget, "relay file")
}
// waitForPeerRangeCoverage waits until the peer has staged enough data to
// cover ranges.
//
// It first waits for the peer's source-half ready marker, then for every file
// named by requiredRangeFileSizes(ranges) to reach its required size. It
// returns nil without waiting when the peer has no endpoint or staging path,
// and skips the size wait when ranges yield no size requirements.
func (a *Adapter) waitForPeerRangeCoverage(ctx context.Context, spec sharedtypes.TransferSpec, peer sharedtypes.CollectivePeerPlan, ranges []sharedtypes.ByteRange) error {
	if peer.Endpoint == "" {
		return nil
	}
	if peer.StagingPath == "" {
		return nil
	}

	readyErr := waitForPeerFanoutSourceHalfReady(ctx, a.client, spec, peer, ranges)
	if readyErr != nil {
		return readyErr
	}

	minSizes := requiredRangeFileSizes(ranges)
	if len(minSizes) == 0 {
		return nil
	}

	budget := fanoutPeerWaitTimeout(spec, ranges)
	return a.waitForPeerRequiredFileSizes(ctx, peer, minSizes, budget, "file")
}
// waitForPeerRequiredFileSizes waits for each file in required (relative path
// to minimum size) to reach its minimum size on the peer.
//
// All files share one deadline, computed once as now+timeout, and one 50ms
// ticker. Files are checked one after another in map iteration order, and the
// first failure aborts the wait. fileKind is only used to label error
// messages.
func (a *Adapter) waitForPeerRequiredFileSizes(
	ctx context.Context,
	peer sharedtypes.CollectivePeerPlan,
	required map[string]int64,
	timeout time.Duration,
	fileKind string,
) error {
	expiry := time.Now().Add(timeout)

	tick := time.NewTicker(50 * time.Millisecond)
	defer tick.Stop()

	for path, wantSize := range required {
		pending := peerFileWait{
			peer:         peer,
			relativePath: path,
			minSize:      wantSize,
			deadline:     expiry,
			ticker:       tick,
			fileKind:     fileKind,
		}
		if waitErr := a.waitForPeerFileSize(ctx, pending); waitErr != nil {
			return waitErr
		}
	}
	return nil
}
// peerFileWait bundles the parameters of a single "wait until a peer file is
// large enough" poll loop, as consumed by waitForPeerFileSize.
type peerFileWait struct {
	// peer supplies the endpoint and staging path that are probed, and the
	// node name used in error messages.
	peer sharedtypes.CollectivePeerPlan
	// relativePath is the file to probe, relative to the peer's staging path.
	relativePath string
	// minSize is the size in bytes the file must reach for the wait to succeed.
	minSize int64
	// deadline is the instant after which the wait gives up with a timeout error.
	deadline time.Time
	// ticker paces the polling; it is supplied (and stopped) by the caller.
	ticker *time.Ticker
	// fileKind is a label for the file ("file", "relay file") used in error messages.
	fileKind string
}
// waitForPeerFileSize polls the peer until the file described by wait is at
// least wait.minSize bytes.
//
// Each round stats the file; success requires a nil Stat error and a large
// enough size. Once wait.deadline has passed a timeout error is returned,
// which wraps the last Stat error if there was one and otherwise reports the
// required and current sizes. Between rounds the loop blocks on wait.ticker,
// and context cancellation is reported through waitForNextPoll.
func (a *Adapter) waitForPeerFileSize(ctx context.Context, wait peerFileWait) error {
	target := wait.peer
	for {
		current, statErr := a.client.Stat(ctx, target.Endpoint, target.StagingPath, wait.relativePath)
		if statErr == nil && current >= wait.minSize {
			return nil
		}

		if time.Now().After(wait.deadline) {
			if statErr == nil {
				return fmt.Errorf("timeout waiting for peer %s %s on %s to reach %d bytes, current=%d", wait.fileKind, wait.relativePath, target.NodeName, wait.minSize, current)
			}
			return fmt.Errorf("timeout waiting for peer %s %s on %s: %w", wait.fileKind, wait.relativePath, target.NodeName, statErr)
		}

		if pollErr := waitForNextPoll(ctx, wait.ticker); pollErr != nil {
			return pollErr
		}
	}
}
// fanoutPeerWaitTimeout returns how long to wait for a peer to stage the data
// described by ranges.
//
// For anything other than a directory per-file fan-out sub task the result is
// collectiveTimeout(spec). For fan-out sub tasks the total number of bytes in
// ranges (empty or inverted ranges are ignored) is converted to seconds at a
// conservative 256 MiB/s, rounded up, plus a fixed 45s allowance. The result
// is then:
//   - collectiveTimeout(spec) if the derived value is smaller than it, or if
//     ranges describe no bytes;
//   - 15 minutes if the derived value exceeds 15 minutes;
//   - the derived value otherwise.
//
// All arithmetic saturates rather than wrapping, so absurdly large ranges
// (for example an End of MaxInt64 used as a sentinel) resolve to the 15 minute
// cap instead of overflowing into a negative duration.
func fanoutPeerWaitTimeout(spec sharedtypes.TransferSpec, ranges []sharedtypes.ByteRange) time.Duration {
	timeout := collectiveTimeout(spec)
	if !isDirectoryPerFileFanoutSubTask(spec) {
		return timeout
	}

	const (
		conservativeBytesPerSecond = int64(256 * 1024 * 1024)
		fixedAllowance             = 45 * time.Second
		maxDerived                 = 15 * time.Minute
		maxInt64                   = int64(1<<63 - 1)
		// maxSeconds is the largest whole-second estimate that can still be
		// converted to a time.Duration and have fixedAllowance added without
		// overflowing int64 nanoseconds.
		maxSeconds = int64((time.Duration(maxInt64) - fixedAllowance) / time.Second)
	)

	totalBytes := int64(0)
	for _, rng := range ranges {
		if rng.End <= rng.Start {
			continue
		}
		span := rng.End - rng.Start
		// span < 0 means End-Start itself wrapped (very negative Start);
		// the second test guards the running sum. Saturate in both cases.
		if span < 0 || span > maxInt64-totalBytes {
			totalBytes = maxInt64
			break
		}
		totalBytes += span
	}
	if totalBytes <= 0 {
		return timeout
	}

	// Ceiling division written so that the numerator can never overflow.
	estimatedSeconds := totalBytes / conservativeBytesPerSecond
	if totalBytes%conservativeBytesPerSecond != 0 {
		estimatedSeconds++
	}
	if estimatedSeconds > maxSeconds {
		estimatedSeconds = maxSeconds
	}

	derived := time.Duration(estimatedSeconds)*time.Second + fixedAllowance
	if derived < timeout {
		return timeout
	}
	if derived > maxDerived {
		return maxDerived
	}
	return derived
}
// waitCollectivePayloads polls the previous ring member's agent until it
// lists at least as many chunks for (spec.TaskID, sessionID, iteration) as
// there are valid entries in expectedRanges, and returns those chunks.
//
// The endpoint polled is spec.CollectiveSpec.Ring.PrevEndpoint after
// normalization; a missing ring spec is reported as an error rather than
// dereferenced. The listing is requested every 100ms until timeout has
// elapsed. Any listing error aborts the wait immediately. Failures are logged
// before being returned, including context cancellation while waiting for the
// next poll.
func (a *Adapter) waitCollectivePayloads(ctx context.Context, spec sharedtypes.TransferSpec, sessionID string, iteration int32, expectedRanges []sharedtypes.ByteRange, timeout time.Duration) ([]sharedtypes.CollectiveChunkPayload, error) {
	expectedCount := countValidRanges(expectedRanges)
	deadline := time.Now().Add(timeout)

	// Ring is nil-able; without this guard a spec lacking ring information
	// would panic on the PrevEndpoint access below.
	if spec.CollectiveSpec.Ring == nil {
		err := fmt.Errorf("wait for collective iteration %d requires a ring collective spec", iteration)
		a.logger.Error("wait collective payloads missing ring spec", "taskID", spec.TaskID, "iteration", iteration, "err", err)
		return nil, err
	}

	prevEndpoint, err := normalizeAgentEndpoint(spec.CollectiveSpec.Ring.PrevEndpoint)
	if err != nil {
		a.logger.Error("normalize previous collective endpoint failed", "taskID", spec.TaskID, "iteration", iteration, "err", err)
		return nil, errutil.Wrap("normalize previous collective endpoint", err)
	}

	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		output, err := a.listCollectiveChunks(ctx, prevEndpoint, collectiveListRequest(spec.TaskID, sessionID, iteration), "collective chunk list")
		if err != nil {
			a.logger.Error("list collective payloads failed", "taskID", spec.TaskID, "iteration", iteration, "expectedCount", expectedCount, "err", err)
			return nil, err
		}
		if len(output.Chunks) >= expectedCount {
			return output.Chunks, nil
		}
		// The deadline is only checked after a listing, so at least one
		// listing is always attempted even with a zero timeout.
		if time.Now().After(deadline) {
			err := fmt.Errorf("timed out waiting for collective iteration %d: expected at least %d chunks, got %d", iteration, expectedCount, len(output.Chunks))
			a.logger.Error("wait collective payloads timed out", "taskID", spec.TaskID, "iteration", iteration, "expectedCount", expectedCount, "chunkCount", len(output.Chunks), "err", err)
			return nil, err
		}
		if err := waitForNextPoll(ctx, ticker); err != nil {
			a.logger.Error("wait collective payload poll canceled", "taskID", spec.TaskID, "iteration", iteration, "err", err)
			return nil, err
		}
	}
}
// waitForNextPoll blocks until the ticker fires or ctx is done, whichever
// comes first.
//
// It returns nil when the ticker fired, the wrapped ctx.Err() when the context
// ended, and a plain error when ctx is nil.
func waitForNextPoll(ctx context.Context, ticker *time.Ticker) error {
	if ctx == nil {
		return fmt.Errorf("wait for next poll requires non-nil context")
	}

	canceled := ctx.Done()
	select {
	case <-ticker.C:
		return nil
	case <-canceled:
		return errutil.Wrap("wait for next poll", ctx.Err())
	}
}
// waitCollectivePayloadMetadata polls endpoint (after normalization) for the
// chunk listing of (taskID, sessionID, iteration) and returns the chunks once
// the listing is considered complete.
//
// The listing is requested every 50ms. While it is empty the loop keeps
// polling until timeout has elapsed. Once it holds at least one chunk,
// completeness is decided by handleRelayCollectiveMetadataPoll, which also
// performs the wait before the next round when the listing is still partial.
// A listing error aborts the wait immediately; failures detected here are
// logged before being returned.
func (a *Adapter) waitCollectivePayloadMetadata(ctx context.Context, endpoint, taskID, sessionID string, iteration int32, timeout time.Duration) ([]sharedtypes.CollectiveChunkPayload, error) {
	expiry := time.Now().Add(timeout)

	target, normErr := normalizeAgentEndpoint(endpoint)
	if normErr != nil {
		a.logger.Error("normalize relay collective endpoint failed", "taskID", taskID, "iteration", iteration, "endpoint", endpoint, "err", normErr)
		return nil, errutil.Wrap("normalize relay collective endpoint", normErr)
	}

	request := collectiveListRequest(taskID, sessionID, iteration)

	tick := time.NewTicker(50 * time.Millisecond)
	defer tick.Stop()

	for {
		listing, listErr := a.listCollectiveChunks(ctx, target, request, "relay collective chunk list")
		if listErr != nil {
			a.logger.Error("list relay collective payload metadata failed", "taskID", taskID, "iteration", iteration, "endpoint", target, "err", listErr)
			return nil, listErr
		}

		if len(listing.Chunks) == 0 {
			// Nothing published yet: give up on expiry, otherwise wait a tick.
			if time.Now().After(expiry) {
				timeoutErr := fmt.Errorf("timed out waiting for relay collective iteration %d", iteration)
				a.logger.Error("wait relay collective payload metadata timed out", "taskID", taskID, "iteration", iteration, "endpoint", target, "err", timeoutErr)
				return nil, timeoutErr
			}
			if pollErr := waitForNextPoll(ctx, tick); pollErr != nil {
				a.logger.Error("wait relay collective payload metadata poll canceled", "taskID", taskID, "iteration", iteration, "endpoint", target, "err", pollErr)
				return nil, pollErr
			}
			continue
		}

		chunks, done, pollErr := a.handleRelayCollectiveMetadataPoll(ctx, taskID, iteration, expiry, listing, tick)
		if pollErr != nil {
			return nil, pollErr
		}
		if done {
			return chunks, nil
		}
	}
}
// collectiveListRequest builds the chunk-list request identifying one
// iteration of one session of a task.
func collectiveListRequest(taskID, sessionID string, iteration int32) sharedtypes.ListCollectiveChunksRequest {
	var request sharedtypes.ListCollectiveChunksRequest
	request.TaskID = taskID
	request.SessionID = sessionID
	request.Iteration = iteration
	return request
}
// listCollectiveChunks POSTs request as JSON to
// endpoint+"/v1/collectives/chunks/list" and decodes the JSON response.
//
// action is a human-readable label woven into error messages. Any status code
// of 300 or above is reported as an error carrying that status. The response
// body is closed on every path. A close failure never hides the more useful
// diagnostic: on a bad status it is appended to the status error, and on the
// success path it is reported under its own "close" label, and only when
// decoding itself succeeded.
func (a *Adapter) listCollectiveChunks(
	ctx context.Context,
	endpoint string,
	request sharedtypes.ListCollectiveChunksRequest,
	action string,
) (sharedtypes.ListCollectiveChunksResponse, error) {
	body, err := json.Marshal(request)
	if err != nil {
		return sharedtypes.ListCollectiveChunksResponse{}, errutil.Wrap(fmt.Sprintf("marshal %s request", action), err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/v1/collectives/chunks/list", bytes.NewReader(body))
	if err != nil {
		return sharedtypes.ListCollectiveChunksResponse{}, errutil.Wrap(fmt.Sprintf("build %s request", action), err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := a.httpClient.Do(req)
	if err != nil {
		return sharedtypes.ListCollectiveChunksResponse{}, errutil.Wrap(fmt.Sprintf("post %s request", action), err)
	}

	if resp.StatusCode >= 300 {
		// The HTTP status is the primary failure; a close error is secondary
		// detail and must not replace it.
		if closeErr := resp.Body.Close(); closeErr != nil {
			return sharedtypes.ListCollectiveChunksResponse{}, fmt.Errorf("list collective chunks failed with status %d (close %s response body: %v)", resp.StatusCode, action, closeErr)
		}
		return sharedtypes.ListCollectiveChunksResponse{}, fmt.Errorf("list collective chunks failed with status %d", resp.StatusCode)
	}

	var output sharedtypes.ListCollectiveChunksResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&output)
	closeErr := resp.Body.Close()
	if decodeErr != nil {
		return sharedtypes.ListCollectiveChunksResponse{}, errutil.Wrap(fmt.Sprintf("decode %s response", action), decodeErr)
	}
	if closeErr != nil {
		return sharedtypes.ListCollectiveChunksResponse{}, errutil.Wrap(fmt.Sprintf("close %s response body", action), closeErr)
	}
	return output, nil
}
// handleRelayCollectiveMetadataPoll evaluates one non-empty relay chunk
// listing and reports whether it is complete.
//
// Results:
//   - (output.Chunks, true, nil) when the listing declares no expected count
//     (ExpectedChunks <= 0) or already holds at least that many chunks;
//   - (nil, false, err) when the listing is partial and deadline has passed,
//     or when waiting for the next poll fails (for example the context ended);
//     both are logged before being returned;
//   - (nil, false, nil) after blocking for one tick, telling the caller to
//     poll again.
func (a *Adapter) handleRelayCollectiveMetadataPoll(
	ctx context.Context,
	taskID string,
	iteration int32,
	deadline time.Time,
	output sharedtypes.ListCollectiveChunksResponse,
	ticker *time.Ticker,
) ([]sharedtypes.CollectiveChunkPayload, bool, error) {
	have := len(output.Chunks)

	if output.ExpectedChunks <= 0 {
		return output.Chunks, true, nil
	}
	if have >= int(output.ExpectedChunks) {
		return output.Chunks, true, nil
	}

	if time.Now().After(deadline) {
		timeoutErr := fmt.Errorf(
			"timed out waiting for complete relay collective iteration %d: have %d chunks, expect %d",
			iteration,
			have,
			output.ExpectedChunks,
		)
		a.logger.Error("wait relay collective payload metadata incomplete", "taskID", taskID, "iteration", iteration, "chunkCount", have, "expectedChunks", output.ExpectedChunks, "err", timeoutErr)
		return nil, false, timeoutErr
	}

	pollErr := waitForNextPoll(ctx, ticker)
	if pollErr != nil {
		a.logger.Error("wait relay collective payload metadata retry canceled", "taskID", taskID, "iteration", iteration, "err", pollErr)
		return nil, false, pollErr
	}
	return nil, false, nil
}