* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package rdma
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)
// fakeChunkClient is an in-memory chunk client stub for tests. File sizes are
// looked up in stats under the key "endpoint|rootPath|relativePath".
type fakeChunkClient struct {
	stats map[string]int64
}

// Stat returns the size registered for the endpoint/rootPath/relativePath
// triple, or a "not found" error when the stats table has no such entry.
func (f fakeChunkClient) Stat(_ context.Context, endpoint, rootPath, relativePath string) (int64, error) {
	size, found := f.stats[fmt.Sprintf("%s|%s|%s", endpoint, rootPath, relativePath)]
	if !found {
		return 0, errors.New("not found")
	}
	return size, nil
}

// ReadAt is not supported by the fake; it always fails with "not implemented".
func (f fakeChunkClient) ReadAt(_ context.Context, _, _, _ string, _, _ int64) ([]byte, error) {
	return nil, errors.New("not implemented")
}
// TestHTTPChunkClientStatAndRawReadWithCRC checks that the HTTP chunk client
// sends stat and raw read requests to the expected endpoints, and that
// ReadAtWithCRC returns the payload with CRC32C equal to "crc" even though the
// X-Chunk-CRC32C header is sent with surrounding spaces.
func TestHTTPChunkClientStatAndRawReadWithCRC(t *testing.T) {
	t.Parallel()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, so failures are reported with
		// t.Errorf (FailNow/Fatalf is only valid on the test goroutine) and the
		// request is answered with an HTTP error instead.
		switch r.URL.Path {
		case "/v1/exports/stat":
			var req sharedtypes.StatExportRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("Decode stat request returned error: %v", err)
				http.Error(w, "bad stat request", http.StatusBadRequest)
				return
			}
			if req.RootPath != "/models" || req.RelativePath != "weights.bin" {
				t.Errorf("unexpected stat request: %#v", req)
				http.Error(w, "unexpected stat request", http.StatusBadRequest)
				return
			}
			if err := json.NewEncoder(w).Encode(sharedtypes.StatExportResponse{SizeBytes: 11}); err != nil {
				t.Errorf("Encode stat response returned error: %v", err)
			}
		case "/v1/exports/readat/raw":
			var req sharedtypes.ReadChunkRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("Decode raw read request returned error: %v", err)
				http.Error(w, "bad read request", http.StatusBadRequest)
				return
			}
			if req.Offset != 2 || req.Length != 5 {
				t.Errorf("unexpected read request: %#v", req)
				http.Error(w, "unexpected read request", http.StatusBadRequest)
				return
			}
			w.Header().Set("X-Chunk-CRC32C", " crc ")
			if _, err := w.Write([]byte("hello")); err != nil {
				t.Errorf("Write raw read response returned error: %v", err)
			}
		default:
			t.Errorf("unexpected path %s", r.URL.Path)
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	client := NewHTTPChunkClient(server.Client())
	size, err := client.Stat(context.Background(), server.URL, "/models", "weights.bin")
	if err != nil {
		t.Fatalf("Stat returned error: %v", err)
	}
	if size != 11 {
		t.Fatalf("expected size 11, got %d", size)
	}
	chunk, err := client.ReadAtWithCRC(context.Background(), server.URL, "/models", "weights.bin", 2, 5)
	if err != nil {
		t.Fatalf("ReadAtWithCRC returned error: %v", err)
	}
	if string(chunk.Payload) != "hello" || chunk.CRC32C != "crc" {
		t.Fatalf("unexpected chunk: %#v", chunk)
	}
}
// TestHTTPChunkClientFallsBackToLegacyJSON checks that ReadAtWithCRC returns
// the payload and digest served by the legacy JSON endpoint when the raw
// endpoint answers with HTTP 500.
func TestHTTPChunkClientFallsBackToLegacyJSON(t *testing.T) {
	t.Parallel()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Runs on a server goroutine: report with t.Errorf, never t.Fatalf.
		switch r.URL.Path {
		case "/v1/exports/readat/raw":
			http.Error(w, "raw unsupported", http.StatusInternalServerError)
		case "/v1/exports/readat":
			resp := sharedtypes.ReadChunkResponse{Data: []byte("legacy"), CRC32C: "digest"}
			if err := json.NewEncoder(w).Encode(resp); err != nil {
				t.Errorf("Encode legacy read response returned error: %v", err)
			}
		default:
			t.Errorf("unexpected path %s", r.URL.Path)
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	chunk, err := NewHTTPChunkClient(server.Client()).ReadAtWithCRC(context.Background(), server.URL, "/models", "weights.bin", 0, 6)
	if err != nil {
		t.Fatalf("ReadAtWithCRC returned error: %v", err)
	}
	if string(chunk.Payload) != "legacy" || chunk.CRC32C != "digest" {
		t.Fatalf("unexpected legacy chunk: %#v", chunk)
	}
}
// TestHTTPChunkClientReportsRawReadFailures checks that a 404 from the raw
// read endpoint surfaces as an error whose text includes the response body.
func TestHTTPChunkClientReportsRawReadFailures(t *testing.T) {
	t.Parallel()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/exports/readat/raw" {
			// Runs on a server goroutine: report with t.Errorf, never t.Fatalf.
			// The distinct status/body keeps this from being mistaken for the
			// expected "missing chunk" reply.
			t.Errorf("unexpected path %s", r.URL.Path)
			http.Error(w, "unexpected path", http.StatusBadRequest)
			return
		}
		http.Error(w, "missing chunk", http.StatusNotFound)
	}))
	defer server.Close()

	_, err := NewHTTPChunkClient(server.Client()).ReadAtWithCRC(context.Background(), server.URL, "/models", "missing.bin", 0, 4)
	if err == nil || !strings.Contains(err.Error(), "missing chunk") {
		t.Fatalf("expected raw read failure with body, got %v", err)
	}
}
func TestWaitHelpersCoverNoopAndReadyPaths(t *testing.T) {
t.Parallel()
spec := sharedtypes.TransferSpec{
TaskID: "wait-helper",
TransferMode: sharedtypes.TransferModePartialPullAllGather,
PreserveExisting: true,
LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
RelativePath: "model.bin",
SizeBytes: 12,
}}},
CollectiveSpec: sharedtypes.CollectiveSpec{
Ring: &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 3},
},
}
peer := sharedtypes.CollectivePeerPlan{
NodeName: "rank0",
Endpoint: "http://rank0:18080",
StagingPath: "/staging",
}
ranges := []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 12}}
marker := fanoutStagingReadyMarkerRelativePath("model.bin")
client := fakeChunkClient{stats: map[string]int64{
"http://rank0:18080|/staging|" + marker: 1,
"http://rank0:18080|/staging|model.bin": 12,
"http://rank0:18080|/staging|relay.bin": 9,
"http://rank0:18080|/staging|empty-marker": 0,
}}
adapter := NewAdapterWithOptions(client, nil, AdapterOptions{ForceTCPFallback: true})
if err := waitForPeerFanoutSourceHalfReady(context.Background(), client, spec, peer, ranges); err != nil {
t.Fatalf("waitForPeerFanoutSourceHalfReady returned error: %v", err)
}
if err := adapter.waitForPeerRangeCoverage(context.Background(), spec, peer, ranges); err != nil {
t.Fatalf("waitForPeerRangeCoverage returned error: %v", err)
}
payloads := []sharedtypes.CollectiveChunkPayload{{Chunk: sharedtypes.TransferredChunk{RelativePath: "relay.bin", Offset: 0, Size: 9}}}
if err := adapter.waitForPeerRelayFileCoverage(context.Background(), spec, peer, payloads); err != nil {
t.Fatalf("waitForPeerRelayFileCoverage returned error: %v", err)
}
if err := waitForPeerFanoutSourceHalfReady(context.Background(), client, spec, sharedtypes.CollectivePeerPlan{}, ranges); err != nil {
t.Fatalf("empty peer no-op returned error: %v", err)
}
if err := waitForNPeerSourceTurn(context.Background(), client, spec, nPeerFanoutSpec{selfPeer: sharedtypes.CollectivePeerPlan{Rank: 0}}); err != nil {
t.Fatalf("rank zero source turn no-op returned error: %v", err)
}
}
// TestFanoutHelperEdgeCases exercises the fanout planning helpers: n-peer spec
// resolution for empty, mismatched and valid specs, peer ordering and lookup,
// ring owner step wrap-around, symmetric role resolution, and the
// parallelism bounding/clamping helpers.
func TestFanoutHelperEdgeCases(t *testing.T) {
	t.Parallel()
	if _, ok := resolveNPeerFanoutSpec(sharedtypes.TransferSpec{}); ok {
		t.Fatalf("expected empty spec to skip n-peer fanout")
	}

	// The ring declares WorldSize 3 but only two peers are listed.
	invalidNPeer := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 3, SelfEndpoint: "self"},
			Peers: []sharedtypes.CollectivePeerPlan{
				{Rank: 0, NodeName: "rank0"},
				{Rank: 1, NodeName: "rank1"},
			},
		},
	}
	if _, ok := resolveNPeerFanoutSpec(invalidNPeer); ok {
		t.Fatalf("expected mismatched world size to skip n-peer fanout")
	}

	// Shallow copy: the Ring pointer is shared with invalidNPeer; only the peer
	// list is replaced with one entry per rank.
	validNPeer := invalidNPeer
	validNPeer.CollectiveSpec.Peers = []sharedtypes.CollectivePeerPlan{
		{Rank: 0, NodeName: "rank0", Endpoint: "self"},
		{Rank: 1, NodeName: "rank1", Endpoint: "peer1"},
		{Rank: 2, NodeName: "rank2", Endpoint: "peer2"},
	}
	np, ok := resolveNPeerFanoutSpec(validNPeer)
	if !ok {
		t.Fatalf("expected valid n-peer spec")
	}
	// Guard the length first so a short peer list fails the test cleanly
	// instead of panicking on the index expressions below.
	if len(np.otherPeers) < 2 {
		t.Fatalf("expected at least two other peers, got %#v", np.otherPeers)
	}
	if np.otherPeers[0].Rank != 2 || np.otherPeers[1].Rank != 1 {
		t.Fatalf("expected peers sorted by ring owner step, got %#v", np.otherPeers)
	}
	if _, ok := np.peerByRank(9); ok {
		t.Fatalf("unexpected peer for missing rank")
	}
	if got := ringOwnerStep(3, 0, 0); got != 3 {
		t.Fatalf("expected self owner step to wrap to world size, got %d", got)
	}

	validSymmetric := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		Parallelism:  99,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 2, SelfEndpoint: "self"},
			Peers: []sharedtypes.CollectivePeerPlan{
				{Rank: 0, NodeName: "self", OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 8}}},
				{Rank: 1, NodeName: "peer", Endpoint: "peer", OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 8, End: 16}}},
			},
		},
	}
	role, ok := resolveSymmetricFanoutRole(validSymmetric)
	if !ok {
		t.Fatalf("expected symmetric fanout role")
	}
	if derived, ok := derivedSymmetricFanoutSourceParallelism(role); !ok || derived < 1 {
		t.Fatalf("expected derived symmetric parallelism, got %d ok=%v", derived, ok)
	}
	if got := boundedSymmetricFanoutSourceParallelism(99); got != maxSymmetricFanoutSourceParallelism {
		t.Fatalf("expected bounded parallelism cap, got %d", got)
	}
	if got := clampConfiguredSymmetricFanoutSourceParallelism(-1); got != 1 {
		t.Fatalf("expected configured clamp floor, got %d", got)
	}
	if got := clampConfiguredSymmetricFanoutSourceParallelism(99); got != 8 {
		t.Fatalf("expected configured clamp ceiling, got %d", got)
	}
}
// TestFanoutMarkersAndResultHelpers checks the fanout sub-task path helper,
// that the ready/done marker writers create files under TargetTempPath, the
// required-size helpers for relay payloads and byte ranges, and the merged
// result produced by logCombinedFanoutResult.
func TestFanoutMarkersAndResultHelpers(t *testing.T) {
	t.Parallel()
	const relPath = "nested/model.bin"
	tempRoot := t.TempDir()
	transfer := sharedtypes.TransferSpec{
		TaskID:           "marker-task",
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath:   tempRoot,
		PreserveExisting: true,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: relPath,
			SizeBytes:    16,
		}}},
	}

	if subPath := fanoutSubTaskRelativePath(transfer); subPath != relPath {
		t.Fatalf("unexpected fanout subtask relative path %q", subPath)
	}
	if err := markFanoutSourceHalfReady(transfer); err != nil {
		t.Fatalf("markFanoutSourceHalfReady returned error: %v", err)
	}
	if err := markFanoutFileDone(transfer); err != nil {
		t.Fatalf("markFanoutFileDone returned error: %v", err)
	}

	// Both marker files must now exist below the temp root.
	markers := []string{
		fanoutStagingReadyMarkerRelativePath(relPath),
		fanoutStagingDoneMarkerRelativePath(relPath),
	}
	for i := range markers {
		markerPath := filepath.Join(transfer.TargetTempPath, markers[i])
		if _, err := os.Stat(markerPath); err != nil {
			t.Fatalf("expected marker %s: %v", markers[i], err)
		}
	}

	// Two overlapping relay chunks on the same file: 0..4 and 2..10.
	relaySizes := requiredRelayFileSizes([]sharedtypes.CollectiveChunkPayload{
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "a.bin", Offset: 0, Size: 4}},
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "a.bin", Offset: 2, Size: 8}},
	})
	if got := relaySizes["a.bin"]; got != 10 {
		t.Fatalf("expected relay required size 10, got %d", got)
	}

	// One empty range (7..7) and one range ending at 9 on the same file.
	rangeSizes := requiredRangeFileSizes([]sharedtypes.ByteRange{
		{RelativePath: "b.bin", Start: 7, End: 7},
		{RelativePath: "b.bin", Start: 3, End: 9},
	})
	if got := rangeSizes["b.bin"]; got != 9 {
		t.Fatalf("expected range required size 9, got %d", got)
	}

	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	started := time.Now()
	firstRun := fanoutTransferRun{
		result:        sharedtypes.TransferResult{BytesTransferred: 3},
		transportPath: sharedtypes.TransportPathRDMA,
	}
	secondRun := fanoutTransferRun{
		result:        sharedtypes.TransferResult{BytesTransferred: 4},
		transportPath: sharedtypes.TransportPathRDMA,
	}
	combined := logCombinedFanoutResult(quiet, transfer, started, firstRun, secondRun)
	if combined.BytesTransferred != 7 {
		t.Fatalf("unexpected merged fanout result: %#v", combined)
	}
	if mergeFanoutTransportPaths("", "") != sharedtypes.TransportPathRDMA {
		t.Fatalf("unexpected merged fanout result: %#v", combined)
	}
}
// TestFanoutWaitHelpersCoverDoneMarkerAndTimeouts checks that
// waitForPeerFanoutFileDone succeeds when the peer's done marker is present,
// and that waitForPeerFileSize with an already-expired deadline reports the
// current size for an undersized file and the stat error for a missing one.
func TestFanoutWaitHelpersCoverDoneMarkerAndTimeouts(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	remote := sharedtypes.CollectivePeerPlan{NodeName: "peer", Endpoint: "endpoint", StagingPath: "root"}
	transfer := sharedtypes.TransferSpec{
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		TimeoutSeconds:   1,
		PreserveExisting: true,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model.bin",
			SizeBytes:    8,
		}}},
	}

	doneStub := fakeChunkClient{stats: map[string]int64{
		"endpoint|root|model.bin.wdfanout.done": 4,
	}}
	if err := waitForPeerFanoutFileDone(ctx, doneStub, transfer, remote, "model.bin"); err != nil {
		t.Fatalf("waitForPeerFanoutFileDone returned error: %v", err)
	}

	// The ticker period is an hour, so it never fires within the test.
	idle := time.NewTicker(time.Hour)
	defer idle.Stop()
	sizeStub := fakeChunkClient{stats: map[string]int64{
		"endpoint|root|model.bin": 2,
	}}
	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	adapter := NewAdapterWithFFI(sizeStub, quiet, false)

	// File exists with size 2 but 4 bytes are required; deadline is in the past.
	wait := peerFileWait{
		peer:         remote,
		relativePath: "model.bin",
		minSize:      4,
		deadline:     time.Now().Add(-time.Millisecond),
		ticker:       idle,
		fileKind:     "file",
	}
	err := adapter.waitForPeerFileSize(ctx, wait)
	if err == nil || !strings.Contains(err.Error(), "current=2") {
		t.Fatalf("expected current-size timeout, got %v", err)
	}

	// File is absent from the stub, so Stat fails with "not found".
	wait.relativePath = "missing.bin"
	wait.minSize = 1
	wait.deadline = time.Now().Add(-time.Millisecond)
	err = adapter.waitForPeerFileSize(ctx, wait)
	if err == nil || !strings.Contains(err.Error(), "not found") {
		t.Fatalf("expected stat error timeout, got %v", err)
	}
}
// TestAdapterSmallBranchHelpers covers three small helpers: the forced TCP
// fallback env switch, the explicit retry limit, and Hugging Face source
// detection. It uses t.Setenv, so it is not marked parallel.
func TestAdapterSmallBranchHelpers(t *testing.T) {
	t.Setenv("WD_FORCE_TCP_FALLBACK", "off")
	if forced := forceTCPFallbackDataPlane(); forced {
		t.Fatalf("expected off to disable forced TCP fallback")
	}

	retrySpec := sharedtypes.TransferSpec{RetryLimit: 3}
	if limit := effectiveChunkRetryLimit(retrySpec); limit != 3 {
		t.Fatalf("expected explicit retry limit")
	}

	hfSpec := sharedtypes.TransferSpec{SourceSegments: []sharedtypes.SourceSegmentPlan{{
		SourceEndpoint: sharedtypes.SourceEndpoint{SourceType: "external", Endpoint: "https://huggingface.co"},
	}}}
	if !transferUsesHuggingFaceSource(hfSpec) {
		t.Fatalf("expected huggingface source detection")
	}
}
// TestAdapterExecutesNPeerFanoutSourcePhaseWithNoopPeerFetches runs
// executeNPeerFanout for rank 0 of a three-node ring whose peers declare no
// owned ranges. It expects the local source file to be copied to the target,
// at least one collective chunk push to reach the self endpoint, and the
// source-ready marker to be written.
func TestAdapterExecutesNPeerFanoutSourcePhaseWithNoopPeerFetches(t *testing.T) {
	t.Parallel()
	sourceRoot := t.TempDir()
	targetRoot := t.TempDir()
	payload := []byte("npeer-source-payload")
	if err := os.WriteFile(filepath.Join(sourceRoot, "model.bin"), payload, 0o600); err != nil {
		t.Fatalf("WriteFile returned error: %v", err)
	}

	// pushed is appended from server handler goroutines and read on the test
	// goroutine, so every access goes through pushedMu.
	var (
		pushedMu sync.Mutex
		pushed   []sharedtypes.PushCollectiveChunkRequest
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Runs on a server goroutine: report with t.Errorf, never t.Fatalf.
		if r.URL.Path != "/v1/collectives/chunks/push" {
			t.Errorf("unexpected path %s", r.URL.Path)
			http.NotFound(w, r)
			return
		}
		var req sharedtypes.PushCollectiveChunkRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("Decode push request returned error: %v", err)
			http.Error(w, "bad push request", http.StatusBadRequest)
			return
		}
		pushedMu.Lock()
		pushed = append(pushed, req)
		pushedMu.Unlock()
		if err := json.NewEncoder(w).Encode(sharedtypes.PushCollectiveChunkResponse{Accepted: true}); err != nil {
			t.Errorf("Encode push response returned error: %v", err)
		}
	}))
	defer server.Close()

	spec := sharedtypes.TransferSpec{
		TaskID:           "npeer-exec",
		ArtifactKey:      "artifact",
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath:   targetRoot,
		PreserveExisting: true,
		ChunkSizeBytes:   4,
		Parallelism:      2,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model.bin",
			SizeBytes:    int64(len(payload)),
			Chunkable:    true,
		}}},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{Path: sourceRoot},
		}},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			SessionID: "session",
			Ring:      &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 3, SelfNode: "rank0", SelfEndpoint: server.URL},
			Peers: []sharedtypes.CollectivePeerPlan{
				{NodeName: "rank0", Rank: 0, Endpoint: server.URL},
				{NodeName: "rank1", Rank: 1, Endpoint: "http://rank1:18080"},
				{NodeName: "rank2", Rank: 2, Endpoint: "http://rank2:18080"},
			},
		},
	}
	np, ok := resolveNPeerFanoutSpec(spec)
	if !ok {
		t.Fatalf("expected n-peer spec to resolve")
	}
	adapter := NewAdapterWithOptions(LocalChunkClient{}, slog.New(slog.NewTextHandler(io.Discard, nil)), AdapterOptions{
		EnableFFI:        false,
		ForceTCPFallback: true,
	})
	result, transportPath, handled, err := adapter.executeNPeerFanout(context.Background(), spec, spec.LogicalManifest.Files, np)
	if err != nil {
		t.Fatalf("executeNPeerFanout returned error: %v", err)
	}
	if !handled || transportPath != sharedtypes.TransportPathTCPFallback {
		t.Fatalf("unexpected fanout handling: handled=%v transportPath=%q", handled, transportPath)
	}
	if result.BytesTransferred != int64(len(payload)) {
		t.Fatalf("expected source payload bytes, got %#v", result)
	}
	if got, err := os.ReadFile(filepath.Join(targetRoot, "model.bin")); err != nil || string(got) != string(payload) {
		t.Fatalf("unexpected target payload %q err=%v", string(got), err)
	}
	pushedMu.Lock()
	pushedCount := len(pushed)
	pushedMu.Unlock()
	if pushedCount == 0 {
		t.Fatalf("expected source-done metadata push")
	}
	if _, err := os.Stat(filepath.Join(targetRoot, fanoutStagingReadyMarkerRelativePath("model.bin"))); err != nil {
		t.Fatalf("expected source ready marker: %v", err)
	}
}
// TestAdapterExecutesNPeerFanoutWithPeerStagingOwnerFetch runs
// executeNPeerFanout for rank 0 of a two-node ring where rank 1 owns bytes
// 5..9 of model.bin and exposes them (plus a ready marker) in a local staging
// directory. It expects the peer-owned range to be copied into the target and
// the reported byte count to exceed the source payload size.
func TestAdapterExecutesNPeerFanoutWithPeerStagingOwnerFetch(t *testing.T) {
	t.Parallel()
	sourceRoot := t.TempDir()
	targetRoot := t.TempDir()
	peerStaging := t.TempDir()
	sourcePayload := []byte("source-half-0000")
	peerPayload := []byte("peer-half-1111")
	if err := os.WriteFile(filepath.Join(sourceRoot, "model.bin"), sourcePayload, 0o600); err != nil {
		t.Fatalf("WriteFile source returned error: %v", err)
	}
	if err := os.WriteFile(filepath.Join(peerStaging, "model.bin"), peerPayload, 0o600); err != nil {
		t.Fatalf("WriteFile peer staging returned error: %v", err)
	}
	if err := os.WriteFile(filepath.Join(peerStaging, fanoutStagingReadyMarkerRelativePath("model.bin")), []byte("1"), 0o600); err != nil {
		t.Fatalf("WriteFile peer ready marker returned error: %v", err)
	}
	spec := sharedtypes.TransferSpec{
		TaskID:           "npeer-peer-fetch",
		ArtifactKey:      "artifact",
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath:   targetRoot,
		PreserveExisting: true,
		ChunkSizeBytes:   4,
		Parallelism:      2,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model.bin",
			SizeBytes:    int64(len(sourcePayload)),
			Chunkable:    true,
		}}},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{Path: sourceRoot},
		}},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			SessionID: "session",
			Ring:      &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 2, SelfNode: "rank0", SelfEndpoint: "http://rank0:18080"},
			Peers: []sharedtypes.CollectivePeerPlan{
				{NodeName: "rank0", Rank: 0, Endpoint: "http://rank0:18080", OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 4}}},
				{NodeName: "rank1", Rank: 1, Endpoint: "http://rank1:18080", StagingPath: peerStaging, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 5, End: 9}}},
			},
		},
	}
	np, ok := resolveNPeerFanoutSpec(spec)
	if !ok {
		t.Fatalf("expected n-peer spec to resolve")
	}
	adapter := NewAdapterWithOptions(LocalChunkClient{}, slog.New(slog.NewTextHandler(io.Discard, nil)), AdapterOptions{
		EnableFFI:        false,
		ForceTCPFallback: true,
	})
	result, transportPath, handled, err := adapter.executeNPeerFanout(context.Background(), spec, spec.LogicalManifest.Files, np)
	if err != nil {
		t.Fatalf("executeNPeerFanout returned error: %v", err)
	}
	if !handled || transportPath != sharedtypes.TransportPathTCPFallback {
		t.Fatalf("unexpected fanout result handled=%v transport=%q", handled, transportPath)
	}
	if result.BytesTransferred <= int64(len(sourcePayload)) {
		t.Fatalf("expected peer fetch bytes to be merged, got %#v", result)
	}
	targetPayload, err := os.ReadFile(filepath.Join(targetRoot, "model.bin"))
	if err != nil {
		t.Fatalf("ReadFile target returned error: %v", err)
	}
	// Guard the length so a truncated target fails the test cleanly instead of
	// panicking on the slice expression below.
	if len(targetPayload) < 9 {
		t.Fatalf("expected target to hold at least 9 bytes, got %q", string(targetPayload))
	}
	if string(targetPayload[5:9]) != string(peerPayload[5:9]) {
		t.Fatalf("expected peer range to be copied into target, got %q", string(targetPayload))
	}
}
// TestCollectivePayloadBuildApplyAndFallbackHelpers covers the collective
// payload helpers in sequence: building local payloads from byte ranges,
// applying them to a prepared file (including CRC-mismatch and missing-file
// failures), deriving a relay fetch sub-spec from relay payloads, and building
// owner-fetch fallback jobs.
func TestCollectivePayloadBuildApplyAndFallbackHelpers(t *testing.T) {
	t.Parallel()
	targetRoot := t.TempDir()
	if err := os.WriteFile(filepath.Join(targetRoot, "model.bin"), []byte("abcdefghijklmnop"), 0o600); err != nil {
		t.Fatalf("WriteFile source payload returned error: %v", err)
	}
	spec := sharedtypes.TransferSpec{
		TaskID:         "collective",
		ArtifactKey:    "artifact",
		TargetTempPath: targetRoot,
		ChunkSizeBytes: 4,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model.bin",
			SizeBytes:    16,
		}}},
		CollectiveSpec: sharedtypes.CollectiveSpec{Ring: &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 3}},
	}
	adapter := NewAdapterWithOptions(LocalChunkClient{}, slog.New(slog.NewTextHandler(io.Discard, nil)), AdapterOptions{ForceTCPFallback: true})

	// Two real ranges plus one empty range (5..5) that should yield no payload.
	ranges := []sharedtypes.ByteRange{
		{RelativePath: "model.bin", Start: 0, End: 4},
		{RelativePath: "model.bin", Start: 8, End: 12},
		{RelativePath: "skip.bin", Start: 5, End: 5},
	}
	payloads, err := adapter.buildLocalPayloads(spec, ranges, targetRoot)
	if err != nil {
		t.Fatalf("buildLocalPayloads returned error: %v", err)
	}
	if len(payloads) != 2 || string(payloads[0].Data) != "abcd" || payloads[1].Chunk.Offset != 8 {
		t.Fatalf("unexpected local payloads: %#v", payloads)
	}

	// The payloads are applied through a handle registered as "model.bin" that
	// actually points at out.bin.
	outPath := filepath.Join(targetRoot, "out.bin")
	file, err := os.OpenFile(outPath, os.O_CREATE|os.O_RDWR, 0o600)
	if err != nil {
		t.Fatalf("OpenFile returned error: %v", err)
	}
	defer file.Close()
	result, err := adapter.applyCollectivePayloads(targetRoot, map[string]*os.File{"model.bin": file}, payloads)
	if err != nil {
		t.Fatalf("applyCollectivePayloads returned error: %v", err)
	}
	if result.BytesTransferred != 8 || result.SucceededChunks != 2 {
		t.Fatalf("unexpected collective apply result: %#v", result)
	}
	badPayload := payloads[0]
	badPayload.Chunk.CRC32C = "bad"
	if _, err := adapter.applyCollectivePayloads(targetRoot, map[string]*os.File{"model.bin": file}, []sharedtypes.CollectiveChunkPayload{badPayload}); err == nil {
		t.Fatalf("expected CRC mismatch to fail")
	}
	if _, err := adapter.applyCollectivePayloads(targetRoot, nil, payloads[:1]); err == nil {
		t.Fatalf("expected missing prepared file to fail")
	}

	owner := sharedtypes.CollectivePeerPlan{NodeName: "peer-a", Rank: 2, Endpoint: "peer", StagingPath: "/stage"}
	relayPayloads := []sharedtypes.CollectiveChunkPayload{
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "b.bin", Offset: 4, Size: 4}, RelayOffset: 44, RelayRDMA: &sharedtypes.RelayRDMAHint{SessionID: "s2"}},
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "a.bin", Offset: 0, Size: 4}, RelayOffset: 11, RelayRDMA: &sharedtypes.RelayRDMAHint{SessionID: "s1", Persistent: true}},
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "a.bin", Offset: 8, Size: 4}, RelayOffset: 22, RelayRDMA: &sharedtypes.RelayRDMAHint{SessionID: "s1"}},
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "ignored.bin", Offset: 0, Size: 0}, RelayRDMA: &sharedtypes.RelayRDMAHint{SessionID: "ignored"}},
	}
	subSpec := buildCollectiveRangeFetchSpec(spec, owner, relayPayloads, 3)
	if subSpec.TransferMode != sharedtypes.TransferModeSingleSourceDirect || subSpec.Parallelism != 2 || len(subSpec.SourceSegments) != 2 {
		t.Fatalf("unexpected relay fetch spec: %#v", subSpec)
	}
	// No segment may reuse a caller-owned hint pointer. Every segment is
	// compared against every input payload so the check does not depend on
	// which payload each segment was derived from.
	for _, segment := range subSpec.SourceSegments {
		for _, relayPayload := range relayPayloads {
			if segment.SourceEndpoint.RelayRDMA == relayPayload.RelayRDMA {
				t.Fatalf("expected relay RDMA hint to be cloned")
			}
		}
	}
	if subSpec.SourceSegments[0].ByteRanges[0].RelativePath != "b.bin" || subSpec.SourceSegments[0].ByteRanges[0].RelayOffset != 44 {
		t.Fatalf("expected first relay session to preserve first-seen order: %#v", subSpec.SourceSegments[0].ByteRanges)
	}
	if subSpec.SourceSegments[1].ByteRanges[0].RelativePath != "a.bin" || subSpec.SourceSegments[1].ByteRanges[0].RelayOffset != 11 {
		t.Fatalf("expected relay byte ranges within a session to be sorted and preserve relay offsets: %#v", subSpec.SourceSegments[1].ByteRanges)
	}

	// peerByRank is captured before owner gains OwnedRanges below, so it still
	// describes an owner with no ranges.
	peerByRank := collectivePeersByRank([]sharedtypes.CollectivePeerPlan{owner})
	if got := collectiveOwnerRank(0, 3, 1); got != 2 {
		t.Fatalf("unexpected owner rank %d", got)
	}
	ownerPeer, jobs, err := buildOwnerFetchFallbackJobs(peerByRank, 2)
	if err != nil {
		t.Fatalf("buildOwnerFetchFallbackJobs returned error: %v", err)
	}
	if ownerPeer.NodeName != "peer-a" || len(jobs) != 0 {
		t.Fatalf("expected owner with no ranges and zero jobs, got owner=%#v jobs=%#v", ownerPeer, jobs)
	}
	owner.OwnedRanges = []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 1, End: 5}}
	_, jobs, err = buildOwnerFetchFallbackJobs(collectivePeersByRank([]sharedtypes.CollectivePeerPlan{owner}), 2)
	if err != nil || len(jobs) != 1 || jobs[0].length != 4 {
		t.Fatalf("unexpected owner jobs: jobs=%#v err=%v", jobs, err)
	}
	if _, _, err := buildOwnerFetchFallbackJobs(peerByRank, 9); err == nil {
		t.Fatalf("expected missing owner rank to fail")
	}
}
// TestAdapterExecuteOwnerFetchFallbackCopiesPeerRanges runs
// executeOwnerFetchFallback for rank 1 of a three-node ring whose other two
// peers each own a 4-byte range of model.bin in local staging directories. It
// expects 8 bytes in 2 chunks over the TCP fallback path, an error from
// buildOwnerFetchFallbackJobs for an unknown rank, and a no-op for a
// single-node ring.
func TestAdapterExecuteOwnerFetchFallbackCopiesPeerRanges(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	targetRoot := t.TempDir()
	rankZeroRoot := t.TempDir()
	rankTwoRoot := t.TempDir()

	// Seed each peer's staging directory with its own copy of model.bin.
	seeds := []struct {
		label, root, content string
	}{
		{"peer0", rankZeroRoot, "abcdefghij"},
		{"peer2", rankTwoRoot, "klmnopqrst"},
	}
	for _, seed := range seeds {
		if err := os.WriteFile(filepath.Join(seed.root, "model.bin"), []byte(seed.content), 0o600); err != nil {
			t.Fatalf("WriteFile %s returned error: %v", seed.label, err)
		}
	}

	prepared, err := os.OpenFile(filepath.Join(targetRoot, "model.bin"), os.O_CREATE|os.O_RDWR, 0o600)
	if err != nil {
		t.Fatalf("OpenFile target returned error: %v", err)
	}
	defer prepared.Close()

	peers := []sharedtypes.CollectivePeerPlan{
		{NodeName: "rank0", Rank: 0, Endpoint: "rank0", StagingPath: rankZeroRoot, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 4}}},
		{NodeName: "rank1", Rank: 1},
		{NodeName: "rank2", Rank: 2, Endpoint: "rank2", StagingPath: rankTwoRoot, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 4, End: 8}}},
	}
	transfer := sharedtypes.TransferSpec{
		TaskID:         "owner-fetch",
		ArtifactKey:    "artifact",
		TargetTempPath: targetRoot,
		ChunkSizeBytes: 4,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model.bin",
			SizeBytes:    10,
		}}},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring:  &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 3},
			Peers: peers,
		},
	}
	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	adapter := NewAdapterWithOptions(LocalChunkClient{}, quiet, AdapterOptions{ForceTCPFallback: true})

	openFiles := map[string]*os.File{"model.bin": prepared}
	outcome, path, err := adapter.executeOwnerFetchFallback(ctx, transfer, openFiles)
	if err != nil {
		t.Fatalf("executeOwnerFetchFallback returned error: %v", err)
	}
	pathOK := path == sharedtypes.TransportPathTCPFallback
	countsOK := outcome.BytesTransferred == 8 && outcome.SucceededChunks == 2
	if !pathOK || !countsOK {
		t.Fatalf("unexpected owner-fetch fallback result: %#v transport=%q", outcome, path)
	}

	noPeers := map[int32]sharedtypes.CollectivePeerPlan{}
	if _, _, err := buildOwnerFetchFallbackJobs(noPeers, 7); err == nil {
		t.Fatalf("expected missing fallback owner to fail")
	}

	// A ring of one has nothing to fetch from anyone.
	singleNode := sharedtypes.TransferSpec{
		CollectiveSpec: sharedtypes.CollectiveSpec{Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 1}},
	}
	_, noopPath, err := adapter.executeOwnerFetchFallback(ctx, singleNode, nil)
	if err != nil || noopPath != sharedtypes.TransportPathTCPFallback {
		t.Fatalf("expected worldsize<=1 owner-fetch no-op, transport=%q err=%v", noopPath, err)
	}
}
// TestAdapterExecuteOwnerFetchWithDirectPullFallsBackToChunkClient runs
// executeOwnerFetchWithDirectPull with FFI disabled and TCP fallback forced.
// It expects bytes 2..6 of the peer's model.bin to land in the target over the
// TCP fallback path, a no-op reporting the RDMA path for a single-node ring,
// and an error when the owning rank is absent from the peer list.
func TestAdapterExecuteOwnerFetchWithDirectPullFallsBackToChunkClient(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	targetRoot := t.TempDir()
	peerRoot := t.TempDir()

	// The target starts as ten dashes so the fetched range is easy to spot.
	targetPath := filepath.Join(targetRoot, "model.bin")
	if err := os.WriteFile(targetPath, []byte("----------"), 0o600); err != nil {
		t.Fatalf("WriteFile target returned error: %v", err)
	}
	peerPath := filepath.Join(peerRoot, "model.bin")
	if err := os.WriteFile(peerPath, []byte("abcdefghij"), 0o600); err != nil {
		t.Fatalf("WriteFile peer returned error: %v", err)
	}

	owner := sharedtypes.CollectivePeerPlan{
		NodeName:    "rank0",
		Rank:        0,
		Endpoint:    "rank0",
		StagingPath: peerRoot,
		OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 2, End: 6}},
	}
	transfer := sharedtypes.TransferSpec{
		TaskID:         "owner-direct",
		ArtifactKey:    "artifact",
		TargetTempPath: targetRoot,
		ChunkSizeBytes: 4,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model.bin",
			SizeBytes:    10,
		}}},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 2},
			Peers: []sharedtypes.CollectivePeerPlan{
				owner,
				{NodeName: "rank1", Rank: 1},
			},
		},
	}
	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	adapter := NewAdapterWithOptions(LocalChunkClient{}, quiet, AdapterOptions{
		EnableFFI:        false,
		ForceTCPFallback: true,
	})

	outcome, path, err := adapter.executeOwnerFetchWithDirectPull(ctx, transfer)
	if err != nil {
		t.Fatalf("executeOwnerFetchWithDirectPull returned error: %v", err)
	}
	if path != sharedtypes.TransportPathTCPFallback || outcome.BytesTransferred != 4 {
		t.Fatalf("unexpected direct-pull fallback result: %#v transport=%q", outcome, path)
	}
	written, err := os.ReadFile(targetPath)
	if err != nil {
		t.Fatalf("ReadFile target returned error: %v", err)
	}
	if fetched := string(written[2:6]); fetched != "cdef" {
		t.Fatalf("expected fallback bytes in target, got %q", string(written))
	}

	// A ring of one has nothing to fetch.
	singleNode := sharedtypes.TransferSpec{
		CollectiveSpec: sharedtypes.CollectiveSpec{Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 1}},
	}
	_, noopPath, err := adapter.executeOwnerFetchWithDirectPull(ctx, singleNode)
	if err != nil || noopPath != sharedtypes.TransportPathRDMA {
		t.Fatalf("expected worldsize<=1 direct owner-fetch no-op, transport=%q err=%v", noopPath, err)
	}

	// Only rank 1 is listed, so the other rank of the two-node ring is missing.
	ownerless := sharedtypes.TransferSpec{
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring:  &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 2},
			Peers: []sharedtypes.CollectivePeerPlan{{Rank: 1}},
		},
	}
	if _, _, err = adapter.executeOwnerFetchWithDirectPull(ctx, ownerless); err == nil {
		t.Fatalf("expected missing owner rank to fail")
	}
}
// TestAdapterExecutePublishedOwnerFetchConsumesIterations runs
// executePublishedOwnerFetch against a fake owner that publishes one 4-byte
// chunk per iteration. It expects both iterations to be fetched into the
// target (8 bytes, 2 chunks, TCP fallback path) and each iteration to be
// acknowledged through an acknowledge-only step request.
func TestAdapterExecutePublishedOwnerFetchConsumesIterations(t *testing.T) {
	t.Parallel()
	targetRoot := t.TempDir()
	peerRoot := t.TempDir()
	targetFile := filepath.Join(targetRoot, "model.bin")
	if err := os.WriteFile(targetFile, []byte("--------"), 0o600); err != nil {
		t.Fatalf("WriteFile target returned error: %v", err)
	}
	if err := os.WriteFile(filepath.Join(peerRoot, "model.bin"), []byte("abcdefgh"), 0o600); err != nil {
		t.Fatalf("WriteFile peer returned error: %v", err)
	}

	// acknowledged is written from server handler goroutines and read on the
	// test goroutine, so every access goes through ackMu.
	var ackMu sync.Mutex
	acknowledged := make(map[int32]bool)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Runs on a server goroutine: report with t.Errorf, never t.Fatalf.
		switch r.URL.Path {
		case "/v1/collectives/chunks/list":
			var req sharedtypes.ListCollectiveChunksRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("Decode list request returned error: %v", err)
				http.Error(w, "bad list request", http.StatusBadRequest)
				return
			}
			// Iteration k (1-based) publishes the chunk at offset (k-1)*4.
			start := int64((req.Iteration - 1) * 4)
			resp := sharedtypes.ListCollectiveChunksResponse{
				TaskID:         req.TaskID,
				Iteration:      req.Iteration,
				ExpectedChunks: 1,
				Chunks: []sharedtypes.CollectiveChunkPayload{{
					Chunk: sharedtypes.TransferredChunk{
						ChunkID:      fmt.Sprintf("chunk-%d", req.Iteration),
						RelativePath: "model.bin",
						Offset:       start,
						Size:         4,
					},
				}},
			}
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(resp); err != nil {
				t.Errorf("Encode list response returned error: %v", err)
			}
		case "/v1/collectives/step":
			var req sharedtypes.CollectiveStepRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("Decode step request returned error: %v", err)
				http.Error(w, "bad step request", http.StatusBadRequest)
				return
			}
			if !req.AcknowledgeOnly {
				t.Errorf("expected acknowledge-only step request: %#v", req)
				http.Error(w, "expected acknowledge-only step request", http.StatusBadRequest)
				return
			}
			ackMu.Lock()
			acknowledged[req.Iteration] = true
			ackMu.Unlock()
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(sharedtypes.CollectiveStepResponse{TaskID: req.TaskID, Iteration: req.Iteration}); err != nil {
				t.Errorf("Encode step response returned error: %v", err)
			}
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	spec := sharedtypes.TransferSpec{
		TaskID:         "published-owner",
		ArtifactKey:    "artifact",
		TargetTempPath: targetRoot,
		ChunkSizeBytes: 4,
		TimeoutSeconds: 1,
		LogicalManifest: sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model.bin",
			SizeBytes:    8,
		}}},
		CollectiveSpec: sharedtypes.CollectiveSpec{Ring: &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 2}},
	}
	owner := sharedtypes.CollectivePeerPlan{
		NodeName:    "rank0",
		Rank:        0,
		Endpoint:    server.URL,
		StagingPath: peerRoot,
		OwnedRanges: []sharedtypes.ByteRange{
			{RelativePath: "model.bin", Start: 0, End: 4},
			{RelativePath: "model.bin", Start: 4, End: 8},
		},
	}
	adapter := NewAdapterWithOptions(LocalChunkClient{}, slog.New(slog.NewTextHandler(io.Discard, nil)), AdapterOptions{
		EnableFFI:        false,
		ForceTCPFallback: true,
	})
	result, transportPath, err := adapter.executePublishedOwnerFetch(context.Background(), spec, owner)
	if err != nil {
		t.Fatalf("executePublishedOwnerFetch returned error: %v", err)
	}
	if result.BytesTransferred != 8 || result.SucceededChunks != 2 || transportPath != sharedtypes.TransportPathTCPFallback {
		t.Fatalf("unexpected published owner result: %#v transport=%q", result, transportPath)
	}

	// Snapshot the acknowledgement state under the lock before asserting.
	ackMu.Lock()
	missingAck := !acknowledged[1] || !acknowledged[2]
	ackDump := fmt.Sprintf("%#v", acknowledged)
	ackMu.Unlock()
	if missingAck {
		t.Fatalf("expected both iterations to be acknowledged, got %s", ackDump)
	}

	targetPayload, err := os.ReadFile(targetFile)
	if err != nil {
		t.Fatalf("ReadFile target returned error: %v", err)
	}
	if string(targetPayload) != "abcdefgh" {
		t.Fatalf("expected copied peer payload, got %q", string(targetPayload))
	}
}
// TestAdapterExecuteSerialSymmetricFanoutMergesInitialAndPeerHalves drives
// executeSerialSymmetricFanout as rank 0 and checks that the pre-existing
// target file ends up holding bytes [0,4) of the source file followed by bytes
// [4,8) of the peer's staged file.
func TestAdapterExecuteSerialSymmetricFanoutMergesInitialAndPeerHalves(t *testing.T) {
	t.Parallel()

	srcDir, peerDir, dstDir := t.TempDir(), t.TempDir(), t.TempDir()
	dstFile := filepath.Join(dstDir, "model.bin")

	// Seed distinct payloads so the merged result reveals which half came
	// from which location.
	if err := os.WriteFile(filepath.Join(srcDir, "model.bin"), []byte("abcdefgh"), 0o600); err != nil {
		t.Fatalf("WriteFile source returned error: %v", err)
	}
	if err := os.WriteFile(filepath.Join(peerDir, "model.bin"), []byte("ijklmnop"), 0o600); err != nil {
		t.Fatalf("WriteFile peer returned error: %v", err)
	}
	if err := os.WriteFile(dstFile, []byte("--------"), 0o600); err != nil {
		t.Fatalf("WriteFile target returned error: %v", err)
	}

	manifest := sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
		RelativePath: "model.bin",
		SizeBytes:    8,
	}}}
	sourceSegment := sharedtypes.SourceSegmentPlan{
		SourceID: "source",
		SourceEndpoint: sharedtypes.SourceEndpoint{
			SourceID:   "source",
			SourceType: "localfs",
			Path:       srcDir,
		},
		ByteRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 4}},
	}
	spec := sharedtypes.TransferSpec{
		TaskID:           "serial-fanout",
		ArtifactKey:      "artifact",
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath:   dstDir,
		TargetFinalPath:  dstDir,
		PreserveExisting: true,
		ChunkSizeBytes:   4,
		Parallelism:      1,
		LogicalManifest:  manifest,
		SourceSegments:   []sharedtypes.SourceSegmentPlan{sourceSegment},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 3, SelfEndpoint: "self"},
		},
	}

	// Rank 0 owns the first half; rank 1 stages the second half in peerDir.
	role := symmetricFanoutRole{
		selfPeer: sharedtypes.CollectivePeerPlan{
			NodeName:    "rank0",
			Rank:        0,
			OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 4}},
		},
		otherPeer: sharedtypes.CollectivePeerPlan{
			NodeName:    "rank1",
			Rank:        1,
			StagingPath: peerDir,
			OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 4, End: 8}},
		},
	}

	quietLogger := slog.New(slog.NewTextHandler(io.Discard, nil))
	adapter := NewAdapterWithOptions(LocalChunkClient{}, quietLogger, AdapterOptions{
		EnableFFI:        false,
		ForceTCPFallback: true,
	})

	sourceSpec := buildSymmetricFanoutSourceSpec(spec)
	outcome, usedPath, err := adapter.executeSerialSymmetricFanout(context.Background(), spec, sourceSpec, role, time.Now())
	if err != nil {
		t.Fatalf("executeSerialSymmetricFanout returned error: %v", err)
	}
	if ok := outcome.BytesTransferred == 8 && usedPath == sharedtypes.TransportPathTCPFallback; !ok {
		t.Fatalf("unexpected serial fanout result: %#v transport=%q", outcome, usedPath)
	}

	merged, err := os.ReadFile(dstFile)
	if err != nil {
		t.Fatalf("ReadFile target returned error: %v", err)
	}
	if got := string(merged); got != "abcdmnop" {
		t.Fatalf("expected merged fanout payload, got %q", got)
	}
}
// TestAdapterExecuteConcurrentSymmetricFanoutParallelBranch runs
// executeParallelSymmetricFanout with
// WD_RDMA_SYMMETRIC_OWNER_FETCH_AFTER_INITIAL set to "false" and verifies the
// target file is merged from the source's first half and the peer's second
// half. The peer staging directory is seeded with a ready marker for the file.
func TestAdapterExecuteConcurrentSymmetricFanoutParallelBranch(t *testing.T) {
	// t.Setenv cannot be combined with t.Parallel, so this test is serial.
	t.Setenv("WD_RDMA_SYMMETRIC_OWNER_FETCH_AFTER_INITIAL", "false")

	srcDir, peerDir, dstDir := t.TempDir(), t.TempDir(), t.TempDir()
	dstFile := filepath.Join(dstDir, "model.bin")

	if err := os.WriteFile(filepath.Join(srcDir, "model.bin"), []byte("abcdefgh"), 0o600); err != nil {
		t.Fatalf("WriteFile source returned error: %v", err)
	}
	if err := os.WriteFile(filepath.Join(peerDir, "model.bin"), []byte("ijklmnop"), 0o600); err != nil {
		t.Fatalf("WriteFile peer returned error: %v", err)
	}
	readyMarker := filepath.Join(peerDir, fanoutStagingReadyMarkerRelativePath("model.bin"))
	if err := os.WriteFile(readyMarker, []byte("ready"), 0o600); err != nil {
		t.Fatalf("WriteFile peer ready marker returned error: %v", err)
	}
	if err := os.WriteFile(dstFile, []byte("--------"), 0o600); err != nil {
		t.Fatalf("WriteFile target returned error: %v", err)
	}

	manifest := sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
		RelativePath: "model.bin",
		SizeBytes:    8,
	}}}
	sourceSegment := sharedtypes.SourceSegmentPlan{
		SourceID: "source",
		SourceEndpoint: sharedtypes.SourceEndpoint{
			SourceID:   "source",
			SourceType: "localfs",
			Path:       srcDir,
		},
		ByteRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 4}},
	}
	collective := sharedtypes.CollectiveSpec{
		Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 3, SelfEndpoint: "self"},
		Peers: []sharedtypes.CollectivePeerPlan{
			{
				NodeName:    "rank0",
				Rank:        0,
				OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 4}},
			},
			{
				NodeName:    "rank1",
				Rank:        1,
				Endpoint:    "peer",
				StagingPath: peerDir,
				OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 4, End: 8}},
			},
		},
	}
	spec := sharedtypes.TransferSpec{
		TaskID:           "parallel-fanout",
		ArtifactKey:      "artifact",
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath:   dstDir,
		TargetFinalPath:  dstDir,
		PreserveExisting: true,
		ChunkSizeBytes:   4,
		Parallelism:      1,
		LogicalManifest:  manifest,
		SourceSegments:   []sharedtypes.SourceSegmentPlan{sourceSegment},
		CollectiveSpec:   collective,
	}

	quietLogger := slog.New(slog.NewTextHandler(io.Discard, nil))
	adapter := NewAdapterWithOptions(LocalChunkClient{}, quietLogger, AdapterOptions{
		EnableFFI:        false,
		ForceTCPFallback: true,
	})
	role := symmetricFanoutRole{
		selfPeer:  spec.CollectiveSpec.Peers[0],
		otherPeer: spec.CollectiveSpec.Peers[1],
	}

	// The function under test receives both the context and its cancel func.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sourceSpec := buildSymmetricFanoutSourceSpec(spec)
	outcome, usedPath, err := adapter.executeParallelSymmetricFanout(ctx, cancel, spec, sourceSpec, role, time.Now())
	if err != nil {
		t.Fatalf("executeParallelSymmetricFanout returned error: %v", err)
	}
	if ok := outcome.BytesTransferred == 8 && usedPath == sharedtypes.TransportPathTCPFallback; !ok {
		t.Fatalf("unexpected concurrent fanout result: result=%#v transport=%q", outcome, usedPath)
	}

	merged, err := os.ReadFile(dstFile)
	if err != nil {
		t.Fatalf("ReadFile target returned error: %v", err)
	}
	if got := string(merged); got != "abcdmnop" {
		t.Fatalf("expected merged fanout payload, got %q", got)
	}
}
// TestCollectiveHTTPHelpersPushAckCloseAndSessionIDs exercises the adapter's
// collective HTTP helpers against a stub server: chunk push with and without
// inline data, iteration acknowledgement for 200/404/500 responses, relay
// session collection and close, and detached cleanup context creation.
func TestCollectiveHTTPHelpersPushAckCloseAndSessionIDs(t *testing.T) {
	t.Parallel()

	// The handler runs on server goroutines. Captured requests are therefore
	// guarded by a mutex, and handler failures are reported with t.Errorf:
	// FailNow (and so t.Fatalf) may only be called from the test goroutine.
	var (
		mu     sync.Mutex
		pushed []sharedtypes.PushCollectiveChunkRequest
		closed []sharedtypes.CloseRDMAExportRequest
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/v1/collectives/chunks/push":
			var req sharedtypes.PushCollectiveChunkRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("Decode push returned error: %v", err)
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			mu.Lock()
			pushed = append(pushed, req)
			mu.Unlock()
			w.WriteHeader(http.StatusOK)
		case "/v1/collectives/step":
			var req sharedtypes.CollectiveStepRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("Decode step returned error: %v", err)
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			// The iteration number doubles as the status code to simulate.
			if req.Iteration == 404 {
				http.NotFound(w, r)
				return
			}
			if req.Iteration == 500 {
				http.Error(w, "bad", http.StatusInternalServerError)
				return
			}
			w.WriteHeader(http.StatusOK)
		case "/v1/exports/rdma/close":
			var req sharedtypes.CloseRDMAExportRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("Decode close returned error: %v", err)
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			mu.Lock()
			closed = append(closed, req)
			mu.Unlock()
			w.WriteHeader(http.StatusOK)
		default:
			t.Errorf("unexpected path %s", r.URL.Path)
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	// snapshotPushed and snapshotClosed copy the captured requests under the lock.
	snapshotPushed := func() []sharedtypes.PushCollectiveChunkRequest {
		mu.Lock()
		defer mu.Unlock()
		return append([]sharedtypes.PushCollectiveChunkRequest(nil), pushed...)
	}
	snapshotClosed := func() []sharedtypes.CloseRDMAExportRequest {
		mu.Lock()
		defer mu.Unlock()
		return append([]sharedtypes.CloseRDMAExportRequest(nil), closed...)
	}

	adapter := NewAdapterWithOptions(LocalChunkClient{}, slog.New(slog.NewTextHandler(io.Discard, nil)), AdapterOptions{ForceTCPFallback: true})
	adapter.httpClient = server.Client()
	payloads := []sharedtypes.CollectiveChunkPayload{{
		Chunk:     sharedtypes.TransferredChunk{RelativePath: "model.bin", Offset: 0, Size: 4},
		Data:      []byte("data"),
		RelayRDMA: &sharedtypes.RelayRDMAHint{SessionID: "relay-1"},
	}}

	// First push: includeData is false, so the request must carry no data.
	target := collectivePushTarget{
		endpoint:      server.URL,
		taskID:        "task",
		sessionID:     "session",
		iteration:     1,
		transportPath: sharedtypes.TransportPathRDMA,
		sourceNode:    "node-a",
		includeData:   false,
	}
	if err := adapter.pushCollectivePayloadsToEndpointWithPayloads(context.Background(), target, payloads); err != nil {
		t.Fatalf("pushCollectivePayloadsToEndpointWithPayloads returned error: %v", err)
	}
	// NOTE(review): the pushed request is JSON-decoded on the server side, so
	// its RelayRDMA pointer appears unable to equal the client's pointer and
	// the last clause may never trigger; confirm the intended check.
	if got := snapshotPushed(); len(got) != 1 || len(got[0].Data) != 0 || got[0].RelayRDMA == payloads[0].RelayRDMA {
		t.Fatalf("unexpected pushed payload: %#v", got)
	}

	// Second push goes through the ring's NextEndpoint and must include data.
	if err := adapter.pushCollectivePayloads(context.Background(), sharedtypes.TransferSpec{
		TaskID: "task",
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{NextEndpoint: server.URL, SelfNode: "node-a"},
		},
	}, "session", 2, payloads); err != nil {
		t.Fatalf("pushCollectivePayloads returned error: %v", err)
	}
	if got := snapshotPushed(); len(got) != 2 || string(got[1].Data) != "data" {
		t.Fatalf("expected second push to include data, got %#v", got)
	}

	if err := adapter.acknowledgeCollectiveIteration(context.Background(), server.URL, "task", "session", 3); err != nil {
		t.Fatalf("acknowledgeCollectiveIteration returned error: %v", err)
	}
	if err := adapter.acknowledgeCollectiveIteration(context.Background(), server.URL, "task", "session", 404); err != nil {
		t.Fatalf("404 acknowledgement should be ignored, got %v", err)
	}
	if err := adapter.acknowledgeCollectiveIteration(context.Background(), server.URL, "task", "session", 500); err == nil {
		t.Fatalf("expected non-OK acknowledgement to fail")
	}

	sessions := make(map[string]string)
	collectRelaySessions(server.URL, payloads, sessions)
	collectRelaySessions(server.URL, []sharedtypes.CollectiveChunkPayload{{}}, sessions)
	if ids := relaySessionIDs(payloads); len(ids) != 1 || ids[0] != "relay-1" {
		t.Fatalf("unexpected relay session ids: %#v", ids)
	}
	adapter.closeRelaySessions(context.Background(), sessions)
	if got := snapshotClosed(); len(got) != 1 || got[0].SessionID != "relay-1" {
		t.Fatalf("expected relay session close, got %#v", got)
	}

	// A nil parent is passed on purpose here (presumably to cover the
	// nil-parent path of newDetachedCleanupContext). Check the result before
	// calling cancel so a nil return fails the test instead of panicking.
	ctx, cancel := newDetachedCleanupContext(nil, time.Second)
	if ctx == nil {
		t.Fatalf("expected detached cleanup context")
	}
	cancel()
}
// TestWaitCollectivePayloadMetadataAndListBranches covers the collective chunk
// listing helpers against a stub /v1/collectives/chunks/list endpoint whose
// behaviour is selected by the request's TaskID: an error status, truncated
// JSON, a response that is incomplete on the first poll, and a complete
// single-chunk response by default.
func TestWaitCollectivePayloadMetadataAndListBranches(t *testing.T) {
	t.Parallel()

	// calls is shared between the server goroutines and the test goroutine,
	// so it is guarded by a mutex. Handler failures use t.Errorf because
	// FailNow (t.Fatalf) must only be called from the test goroutine.
	var (
		mu    sync.Mutex
		calls int
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/collectives/chunks/list" {
			t.Errorf("unexpected path %s", r.URL.Path)
			http.NotFound(w, r)
			return
		}
		var req sharedtypes.ListCollectiveChunksRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("Decode list request returned error: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		mu.Lock()
		calls++
		seen := calls
		mu.Unlock()
		switch req.TaskID {
		case "bad-status":
			http.Error(w, "bad", http.StatusInternalServerError)
		case "bad-json":
			_, _ = w.Write([]byte("{"))
		case "wait-once":
			// First poll reports the expected count but no chunks yet.
			if seen == 1 {
				_ = json.NewEncoder(w).Encode(sharedtypes.ListCollectiveChunksResponse{ExpectedChunks: 2})
				return
			}
			_ = json.NewEncoder(w).Encode(sharedtypes.ListCollectiveChunksResponse{
				ExpectedChunks: 2,
				Chunks: []sharedtypes.CollectiveChunkPayload{
					{Chunk: sharedtypes.TransferredChunk{RelativePath: "a.bin", Size: 1}},
					{Chunk: sharedtypes.TransferredChunk{RelativePath: "b.bin", Size: 1}},
				},
			})
		default:
			_ = json.NewEncoder(w).Encode(sharedtypes.ListCollectiveChunksResponse{
				ExpectedChunks: 1,
				Chunks:         []sharedtypes.CollectiveChunkPayload{{Chunk: sharedtypes.TransferredChunk{RelativePath: "model.bin", Size: 1}}},
			})
		}
	}))
	defer server.Close()

	adapter := NewAdapterWithOptions(LocalChunkClient{}, slog.New(slog.NewTextHandler(io.Discard, nil)), AdapterOptions{ForceTCPFallback: true})
	adapter.httpClient = server.Client()
	spec := sharedtypes.TransferSpec{
		TaskID: "task",
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{PrevEndpoint: server.URL},
		},
	}

	payloads, err := adapter.waitCollectivePayloads(context.Background(), spec, "session", 1, []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 1}}, time.Second)
	if err != nil || len(payloads) != 1 {
		t.Fatalf("waitCollectivePayloads returned payloads=%#v err=%v", payloads, err)
	}
	if _, err := adapter.waitCollectivePayloads(context.Background(), sharedtypes.TransferSpec{
		CollectiveSpec: sharedtypes.CollectiveSpec{Ring: &sharedtypes.RingPeerPlan{PrevEndpoint: "://bad"}},
	}, "session", 1, nil, time.Millisecond); err == nil {
		t.Fatalf("expected bad prev endpoint to fail")
	}

	// Reset the counter so the "wait-once" branch sees its first poll as call 1.
	mu.Lock()
	calls = 0
	mu.Unlock()
	payloads, err = adapter.waitCollectivePayloadMetadata(context.Background(), server.URL, "wait-once", "session", 2, time.Second)
	if err != nil || len(payloads) != 2 {
		t.Fatalf("waitCollectivePayloadMetadata returned payloads=%#v err=%v", payloads, err)
	}

	if _, err := adapter.listCollectiveChunks(context.Background(), server.URL, collectiveListRequest("bad-status", "session", 1), "test list"); err == nil {
		t.Fatalf("expected bad status list to fail")
	}
	if _, err := adapter.listCollectiveChunks(context.Background(), server.URL, collectiveListRequest("bad-json", "session", 1), "test list"); err == nil {
		t.Fatalf("expected bad json list to fail")
	}

	// Stop the ticker so it does not outlive the test.
	idleTicker := time.NewTicker(time.Hour)
	defer idleTicker.Stop()
	incomplete := sharedtypes.ListCollectiveChunksResponse{ExpectedChunks: 2, Chunks: []sharedtypes.CollectiveChunkPayload{{}}}
	if _, _, err := adapter.handleRelayCollectiveMetadataPoll(context.Background(), "task", 1, time.Now().Add(-time.Second), incomplete, idleTicker); err == nil {
		t.Fatalf("expected incomplete metadata poll to fail after deadline")
	}
}
func TestAdapterHonorsForcedTCPFallbackForNativePath(t *testing.T) {
t.Setenv("WD_FORCE_TCP_FALLBACK", "true")
adapter := NewAdapterWithOptions(LocalChunkClient{}, nil, AdapterOptions{
EnableFFI: true,
ForceTCPFallback: false,
})
if !adapter.forceTCPFallback {
t.Fatalf("expected environment to force TCP fallback")
}
if got := adapter.determineTransportPath(); got != sharedtypes.TransportPathTCPFallback {
t.Fatalf("expected forced transport path TCP fallback, got %q", got)
}
if _, err := adapter.tryExecuteDirectPullWithFFI(sharedtypes.TransferSpec{TaskID: "task-a"}); err == nil {
t.Fatalf("expected forced TCP fallback to skip RDMA FFI")
}
}
func TestChunkCRCInjectionBypassesNativeFFI(t *testing.T) {
t.Setenv("WD_INJECT_CHUNK_CRC_MISMATCH_ONCE", "true")
if !shouldBypassDirectPullFFI(sharedtypes.TransferSpec{}, []sharedtypes.ArtifactFile{{RelativePath: "weights.bin"}}) {
t.Fatalf("expected CRC fault injection to use Go chunk path for retry validation")
}
}
func TestAdapterExecutesConcurrentStripedTransfer(t *testing.T) {
t.Parallel()
sourceRoot := t.TempDir()
targetRoot := t.TempDir()
if err := os.WriteFile(filepath.Join(sourceRoot, "weights.bin"), []byte("abcdefghijklmnopqrstuvwxyz"), 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
adapter := NewAdapter(LocalChunkClient{}, nil)
result, err := adapter.Execute(context.Background(), sharedtypes.TransferSpec{
TaskID: "task-1",
ArtifactKey: "artifact-1",
TransferMode: sharedtypes.TransferModeDirectStriped,
TargetTempPath: targetRoot,
ChunkSizeBytes: 4,
Parallelism: 2,
LogicalManifest: sharedtypes.LogicalManifest{
ArtifactKey: "artifact-1",
ChunkSizeBytes: 4,
Digest: "digest-1",
Files: []sharedtypes.ArtifactFile{{
RelativePath: "weights.bin",
SizeBytes: 26,
Kind: sharedtypes.ArtifactFileKindSafeTensors,
Chunkable: true,
Required: true,
}},
},
SourceSegments: []sharedtypes.SourceSegmentPlan{
{
SourceID: "source-1",
SourceEndpoint: sharedtypes.SourceEndpoint{
SourceID: "source-1",
SourceType: "rdma",
Endpoint: "10.0.0.1",
Path: sourceRoot,
},
ByteRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 13}},
},
{
SourceID: "source-2",
SourceEndpoint: sharedtypes.SourceEndpoint{
SourceID: "source-2",
SourceType: "rdma",
Endpoint: "10.0.0.2",
Path: sourceRoot,
},
ByteRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 13, End: 26}},
},
},
})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
if result.ChunkCount != 2 {
t.Fatalf("expected 2 chunks, got %d", result.ChunkCount)
}
data, err := os.ReadFile(filepath.Join(targetRoot, "weights.bin"))
if err != nil {
t.Fatalf("ReadFile returned error: %v", err)
}
if string(data) != "abcdefghijklmnopqrstuvwxyz" {
t.Fatalf("unexpected target content: %s", string(data))
}
if result.SolidifiedManifest == nil || len(result.SolidifiedManifest.Chunks) != 2 {
t.Fatalf("expected solidified manifest to contain 2 chunks")
}
}
func TestAdapterExecutesSingleSourceDirectoryTransfer(t *testing.T) {
t.Parallel()
sourceRoot := t.TempDir()
targetRoot := t.TempDir()
if err := os.WriteFile(filepath.Join(sourceRoot, "shard0.bin"), []byte("abcdefgh"), 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
adapter := NewAdapter(LocalChunkClient{}, nil)
result, err := adapter.Execute(context.Background(), sharedtypes.TransferSpec{
TaskID: "task-single-dir",
ArtifactKey: "artifact-dir",
TransferMode: sharedtypes.TransferModeSingleSourceDirect,
TargetTempPath: targetRoot,
ChunkSizeBytes: 4,
Parallelism: 1,
LogicalManifest: sharedtypes.LogicalManifest{
ArtifactKey: "artifact-dir",
ChunkSizeBytes: 4,
Digest: "digest-single-dir",
Files: []sharedtypes.ArtifactFile{{
RelativePath: "shard0.bin",
SizeBytes: 8,
Kind: sharedtypes.ArtifactFileKindAuxiliary,
Chunkable: false,
Required: true,
}},
},
SourceSegments: []sharedtypes.SourceSegmentPlan{{
SourceID: "source-1",
SourceEndpoint: sharedtypes.SourceEndpoint{
SourceID: "source-1",
SourceType: "rdma",
Path: sourceRoot,
},
}},
})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
if result.BytesTransferred != 8 {
t.Fatalf("expected 8 transferred bytes, got %d", result.BytesTransferred)
}
if result.ChunkCount != 2 {
t.Fatalf("expected 2 chunks, got %d", result.ChunkCount)
}
data, err := os.ReadFile(filepath.Join(targetRoot, "shard0.bin"))
if err != nil {
t.Fatalf("ReadFile returned error: %v", err)
}
if string(data) != "abcdefgh" {
t.Fatalf("unexpected target content: %s", string(data))
}
}
func TestAdapterExecutesMultiFileDirectoryPerFileTransfer(t *testing.T) {
t.Parallel()
sourceRoot := t.TempDir()
targetRoot := t.TempDir()
files := map[string][]byte{
"config.json": []byte(`{"model":"qwen"}`),
"nested/tokenizer_config.json": []byte(`{"tokenizer":"qwen"}`),
"model-00001-of-00002.safetensors": bytes.Repeat([]byte("a"), 4096),
"model-00002-of-00002.safetensors": bytes.Repeat([]byte("b"), 2048),
}
for rel, payload := range files {
path := filepath.Join(sourceRoot, filepath.FromSlash(rel))
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatalf("MkdirAll returned error: %v", err)
}
if err := os.WriteFile(path, payload, 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
}
adapter := NewAdapterWithOptions(LocalChunkClient{}, nil, AdapterOptions{EnableFFI: false, ForceTCPFallback: true})
result, err := adapter.Execute(context.Background(), sharedtypes.TransferSpec{
TaskID: "task-single-dir-multi",
ArtifactKey: "qwen3-8b",
TransferMode: sharedtypes.TransferModeSingleSourceDirect,
TargetTempPath: targetRoot,
ChunkSizeBytes: 1024,
Parallelism: 4,
LogicalManifest: sharedtypes.LogicalManifest{
ArtifactKey: "qwen3-8b",
ChunkSizeBytes: 1024,
Files: []sharedtypes.ArtifactFile{
{RelativePath: "config.json", SizeBytes: int64(len(files["config.json"])), Kind: sharedtypes.ArtifactFileKindJSON, Chunkable: false, Required: true},
{RelativePath: "model-00001-of-00002.safetensors", SizeBytes: int64(len(files["model-00001-of-00002.safetensors"])), Kind: sharedtypes.ArtifactFileKindSafeTensors, Chunkable: true, Required: true},
{RelativePath: "model-00002-of-00002.safetensors", SizeBytes: int64(len(files["model-00002-of-00002.safetensors"])), Kind: sharedtypes.ArtifactFileKindSafeTensors, Chunkable: true, Required: true},
{RelativePath: "nested/tokenizer_config.json", SizeBytes: int64(len(files["nested/tokenizer_config.json"])), Kind: sharedtypes.ArtifactFileKindTokenizer, Chunkable: false, Required: true},
},
},
SourceSegments: []sharedtypes.SourceSegmentPlan{{
SourceID: "source-1",
SourceEndpoint: sharedtypes.SourceEndpoint{SourceID: "source-1", SourceType: "node", Path: sourceRoot},
}},
})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
if result.ChunkCount < int32(len(files)) {
t.Fatalf("expected per-file chunks to be recorded, got %#v", result)
}
for rel, expected := range files {
actual, err := os.ReadFile(filepath.Join(targetRoot, filepath.FromSlash(rel)))
if err != nil {
t.Fatalf("ReadFile %s returned error: %v", rel, err)
}
if !bytes.Equal(actual, expected) {
t.Fatalf("unexpected copied content for %s", rel)
}
}
}
// TestBuildPerFileTransferSpecFiltersStripedRanges checks that
// buildPerFileTransferSpec, given a two-file striped spec and the first file,
// returns a sub-spec whose manifest lists only that file and whose two source
// segments keep only byte ranges for it.
func TestBuildPerFileTransferSpecFiltersStripedRanges(t *testing.T) {
	t.Parallel()

	manifest := sharedtypes.LogicalManifest{
		Files: []sharedtypes.ArtifactFile{
			{RelativePath: "a.safetensors", SizeBytes: 16, Chunkable: true, Required: true},
			{RelativePath: "b.json", SizeBytes: 4, Chunkable: false, Required: true},
		},
	}
	// s0 carries ranges for both files; s1 only for a.safetensors.
	segments := []sharedtypes.SourceSegmentPlan{
		{
			SourceID: "s0",
			SourceEndpoint: sharedtypes.SourceEndpoint{
				SourceID: "s0",
				Path:     "/src",
			},
			ByteRanges: []sharedtypes.ByteRange{
				{RelativePath: "a.safetensors", Start: 0, End: 8},
				{RelativePath: "b.json", Start: 0, End: 4},
			},
		},
		{
			SourceID: "s1",
			SourceEndpoint: sharedtypes.SourceEndpoint{
				SourceID: "s1",
				Path:     "/src",
			},
			ByteRanges: []sharedtypes.ByteRange{
				{RelativePath: "a.safetensors", Start: 8, End: 16},
			},
		},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode:    sharedtypes.TransferModeDirectStriped,
		LogicalManifest: manifest,
		SourceSegments:  segments,
	}

	perFile := buildPerFileTransferSpec(spec, spec.LogicalManifest.Files[0])

	keptFiles := perFile.LogicalManifest.Files
	if len(keptFiles) != 1 || keptFiles[0].RelativePath != "a.safetensors" {
		t.Fatalf("unexpected logical manifest files: %+v", keptFiles)
	}
	if n := len(perFile.SourceSegments); n != 2 {
		t.Fatalf("expected 2 filtered source segments, got %d", n)
	}
	for i := range perFile.SourceSegments {
		for _, kept := range perFile.SourceSegments[i].ByteRanges {
			if kept.RelativePath != "a.safetensors" {
				t.Fatalf("unexpected range after filtering: %+v", kept)
			}
		}
	}
}
// TestBuildPerFileTransferSpecDowngradesNonChunkableFanoutFile checks the
// per-file sub-spec built for a non-chunkable file in a two-peer ring plan.
// Despite "Downgrades" in the name, the assertions require the transfer mode
// to remain partial-pull all-gather, the collective spec to be preserved, the
// single source segment to carry no byte ranges, and only rank 0 to own the
// file's full range.
func TestBuildPerFileTransferSpecDowngradesNonChunkableFanoutFile(t *testing.T) {
	t.Parallel()

	configFile := sharedtypes.ArtifactFile{RelativePath: "config.json", SizeBytes: 128, Chunkable: false, Required: true}
	collective := sharedtypes.CollectiveSpec{
		Mode: sharedtypes.CollectiveModeRing,
		Ring: &sharedtypes.RingPeerPlan{SelfNode: "n0", WorldSize: 2},
		Peers: []sharedtypes.CollectivePeerPlan{
			{NodeName: "n0", Rank: 0, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "config.json", Start: 0, End: 128}}},
			{NodeName: "n2", Rank: 1},
		},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		LogicalManifest: sharedtypes.LogicalManifest{
			Files: []sharedtypes.ArtifactFile{configFile},
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceID: "src",
			SourceEndpoint: sharedtypes.SourceEndpoint{
				SourceID: "src",
				Path:     "/hub",
			},
			ByteRanges: []sharedtypes.ByteRange{{RelativePath: "config.json", Start: 0, End: 128}},
		}},
		CollectiveSpec: collective,
	}

	perFile := buildPerFileTransferSpec(spec, spec.LogicalManifest.Files[0])

	if mode := perFile.TransferMode; mode != sharedtypes.TransferModePartialPullAllGather {
		t.Fatalf("expected n-peer whole-file fanout to stay on partial pull allgather, got %s", mode)
	}
	if perFile.CollectiveSpec.Ring == nil || len(perFile.CollectiveSpec.Peers) != 2 {
		t.Fatalf("expected collective spec preserved, got %+v", perFile.CollectiveSpec)
	}
	// The source segment is kept, but its explicit byte ranges are dropped.
	if segs := perFile.SourceSegments; len(segs) != 1 || len(segs[0].ByteRanges) != 0 {
		t.Fatalf("expected rank-0 to keep one whole-file source segment, got %+v", segs)
	}

	ownerRanges := perFile.CollectiveSpec.Peers[0].OwnedRanges
	if len(ownerRanges) != 1 || ownerRanges[0].Start != 0 || ownerRanges[0].End != 128 {
		t.Fatalf("expected rank-0 peer to own full file, got %+v", ownerRanges)
	}
	otherRanges := perFile.CollectiveSpec.Peers[1].OwnedRanges
	if len(otherRanges) != 0 {
		t.Fatalf("expected non-owner peer to have no owned ranges, got %+v", otherRanges)
	}
}
func TestExecuteNPeerOwnerFetchNoopDoesNotDowngradeTransport(t *testing.T) {
t.Parallel()
adapter := NewAdapter(LocalChunkClient{}, nil)
result, transportPath, err := adapter.executeNPeerOwnerFetch(
context.Background(),
sharedtypes.TransferSpec{
TaskID: "noop-npeer-fetch",
CollectiveSpec: sharedtypes.CollectiveSpec{
Mode: sharedtypes.CollectiveModeRing,
Ring: &sharedtypes.RingPeerPlan{SelfNode: "n0", WorldSize: 2},
},
},
sharedtypes.CollectivePeerPlan{
NodeName: "n2",
Rank: 1,
Endpoint: "http://192.168.200.14:18080",
OwnedRanges: nil,
},
)
if err != nil {
t.Fatalf("executeNPeerOwnerFetch returned error: %v", err)
}
if result.BytesTransferred != 0 || result.ChunkCount != 0 {
t.Fatalf("expected empty result for no-op peer step, got %+v", result)
}
if transportPath != "" {
t.Fatalf("expected no-op peer step to leave transport path empty, got %q", transportPath)
}
}
func TestResolveFilesHonorsCanceledContextForHuggingFaceManifest(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"sha": "commit-1",
"siblings": []map[string]any{
{
"rfilename": "model.safetensors",
"size": 8,
},
},
})
}))
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := resolveFiles(ctx, sharedtypes.TransferSpec{
ChunkSizeBytes: 1024,
SourceSegments: []sharedtypes.SourceSegmentPlan{{
SourceEndpoint: sharedtypes.SourceEndpoint{
SourceType: "huggingface",
Endpoint: server.URL,
Path: "org/model",
},
}},
})
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context canceled, got %v", err)
}
}
// TestWaitForNextPollHonorsCanceledContext expects waitForNextPoll to return
// an error wrapping context.Canceled when the context is already canceled and
// the ticker will not fire for an hour.
func TestWaitForNextPollHonorsCanceledContext(t *testing.T) {
	t.Parallel()

	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	// An hour-long period keeps the tick branch from firing during the test.
	idle := time.NewTicker(time.Hour)
	defer idle.Stop()

	err := waitForNextPoll(canceledCtx, idle)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("expected context canceled, got %v", err)
	}
}
// TestCloseRelaySessionsDetachesCancellationButPreservesContextValues calls
// closeRelaySessions with an already-canceled parent context carrying a trace
// value. The stub transport records the value visible on the outgoing request
// and rejects the request if its context is canceled.
func TestCloseRelaySessionsDetachesCancellationButPreservesContextValues(t *testing.T) {
	t.Parallel()

	type contextKey string
	const traceKey contextKey = "trace-id"

	// Buffered so the transport never blocks on the single expected request.
	observed := make(chan any, 1)
	transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
		reqCtx := req.Context()
		observed <- reqCtx.Value(traceKey)
		if err := reqCtx.Err(); err != nil {
			return nil, fmt.Errorf("cleanup request context should not be canceled: %w", err)
		}
		resp := &http.Response{
			StatusCode: http.StatusNoContent,
			Body:       io.NopCloser(bytes.NewReader(nil)),
		}
		return resp, nil
	})
	adapter := &Adapter{
		httpClient: &http.Client{Transport: transport},
		logger:     slog.Default(),
	}

	withTrace := context.WithValue(context.Background(), traceKey, "trace-1")
	parent, cancel := context.WithCancel(withTrace)
	cancel()

	adapter.closeRelaySessions(parent, map[string]string{"session-1": "http://127.0.0.1:18080"})

	deadline := time.NewTimer(time.Second)
	defer deadline.Stop()
	select {
	case <-deadline.C:
		t.Fatalf("expected close relay request to be issued")
	case got := <-observed:
		if got != "trace-1" {
			t.Fatalf("expected context value trace-1, got %v", got)
		}
	}
}
// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// a test can supply an http.Client transport inline.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by invoking the wrapped function with
// req and returning its results unchanged.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
// TestHTTPChunkClientReadsRemoteRanges checks that an HTTPChunkClient can Stat
// a remote file and read a byte range from it. The stub server answers the
// stat route with a fixed size of 8 and serves ranges of the on-disk file on
// both the raw and the JSON read routes.
func TestHTTPChunkClientReadsRemoteRanges(t *testing.T) {
	t.Parallel()

	sourceRoot := t.TempDir()
	if err := os.WriteFile(filepath.Join(sourceRoot, "weights.bin"), []byte("abcdefgh"), 0o644); err != nil {
		t.Fatalf("WriteFile returned error: %v", err)
	}

	// readRange decodes a ReadChunkRequest from r and returns the requested
	// slice of the file. It runs on a server goroutine, so failures are
	// reported with t.Errorf plus an HTTP error: FailNow (t.Fatalf) must only
	// be called from the test goroutine. ok is false when a failure response
	// has already been written.
	readRange := func(w http.ResponseWriter, r *http.Request) (chunk []byte, ok bool) {
		var req sharedtypes.ReadChunkRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("Decode returned error: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return nil, false
		}
		data, err := os.ReadFile(filepath.Join(req.RootPath, req.RelativePath))
		if err != nil {
			t.Errorf("ReadFile returned error: %v", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return nil, false
		}
		end := req.Offset + req.Length
		return data[req.Offset:end], true
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/v1/exports/stat":
			_ = json.NewEncoder(w).Encode(sharedtypes.StatExportResponse{SizeBytes: 8})
		case "/v1/exports/readat/raw":
			chunk, ok := readRange(w, r)
			if !ok {
				return
			}
			if _, err := w.Write(chunk); err != nil {
				t.Errorf("Write returned error: %v", err)
			}
		case "/v1/exports/readat":
			chunk, ok := readRange(w, r)
			if !ok {
				return
			}
			_ = json.NewEncoder(w).Encode(sharedtypes.ReadChunkResponse{Data: chunk})
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	client := NewHTTPChunkClient(server.Client())
	size, err := client.Stat(context.Background(), server.URL, sourceRoot, "weights.bin")
	if err != nil {
		t.Fatalf("Stat returned error: %v", err)
	}
	if size != 8 {
		t.Fatalf("expected size 8, got %d", size)
	}

	// Offset 2, length 3 of "abcdefgh" is "cde".
	data, err := client.ReadAt(context.Background(), server.URL, sourceRoot, "weights.bin", 2, 3)
	if err != nil {
		t.Fatalf("ReadAt returned error: %v", err)
	}
	if string(data) != "cde" {
		t.Fatalf("unexpected remote data: %s", string(data))
	}
}
func TestAdapterExecutesPartialPullAllGather(t *testing.T) {
t.Parallel()
sourceRoot := t.TempDir()
targetRoot := t.TempDir()
peerA := t.TempDir()
peerB := t.TempDir()
peerC := t.TempDir()
payload := []byte("abcdefghijklmnopqrstuvwxyz")
for _, dir := range []string{sourceRoot, peerA, peerB, peerC} {
if err := os.WriteFile(filepath.Join(dir, "weights.bin"), payload, 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
}
adapter := NewAdapter(LocalChunkClient{}, nil)
result, err := adapter.Execute(context.Background(), sharedtypes.TransferSpec{
TaskID: "task-partial",
ArtifactKey: "artifact-1",
TransferMode: sharedtypes.TransferModePartialPullAllGather,
TargetTempPath: targetRoot,
ChunkSizeBytes: 4,
Parallelism: 2,
LogicalManifest: sharedtypes.LogicalManifest{
ArtifactKey: "artifact-1",
ChunkSizeBytes: 4,
Digest: "digest-1",
Files: []sharedtypes.ArtifactFile{{
RelativePath: "weights.bin",
SizeBytes: int64(len(payload)),
Kind: sharedtypes.ArtifactFileKindSafeTensors,
Chunkable: true,
Required: true,
}},
},
CollectiveSpec: sharedtypes.CollectiveSpec{
Mode: sharedtypes.CollectiveModeRing,
Ring: &sharedtypes.RingPeerPlan{
SelfNode: "node-b",
PrevNode: "node-a",
NextNode: "node-c",
Rank: 1,
WorldSize: 3,
IterationCount: 2,
},
Peers: []sharedtypes.CollectivePeerPlan{
{NodeName: "node-a", Rank: 0, StagingPath: peerA, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 9}}},
{NodeName: "node-b", Rank: 1, StagingPath: peerB, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 9, End: 18}}},
{NodeName: "node-c", Rank: 2, StagingPath: peerC, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 18, End: 26}}},
},
},
SourceSegments: []sharedtypes.SourceSegmentPlan{
{
SourceID: "source-b",
SourceEndpoint: sharedtypes.SourceEndpoint{
SourceID: "source-b",
SourceType: "rdma",
Path: sourceRoot,
},
ByteRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 9, End: 18}},
},
},
})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
data, err := os.ReadFile(filepath.Join(targetRoot, "weights.bin"))
if err != nil {
t.Fatalf("ReadFile returned error: %v", err)
}
if string(data) != string(payload) {
t.Fatalf("unexpected target content: %s", string(data))
}
if result.TransportPath != sharedtypes.TransportPathTCPFallback {
t.Fatalf("expected transport path %q, got %q", sharedtypes.TransportPathTCPFallback, result.TransportPath)
}
if result.SolidifiedManifest == nil || len(result.SolidifiedManifest.Chunks) != 6 {
t.Fatalf("expected 6 chunk-sized entries in solidified manifest")
}
}
// TestAdapterExecutesRingCollectiveOverHTTPAPI runs Execute as ring rank 1
// ("node-b") of a three-node ring whose neighbours are httptest servers.
// The previous neighbour answers chunk-list requests for iterations 1 and 2
// with the byte ranges owned by node-a and node-c; the next neighbour records
// every chunk push it receives. The local source segment supplies bytes
// [9, 18). The test passes when the target file equals the full payload, the
// reported transport path is the TCP fallback, and two pushes were recorded.
func TestAdapterExecutesRingCollectiveOverHTTPAPI(t *testing.T) {
	t.Parallel()
	sourceRoot := t.TempDir()
	targetRoot := t.TempDir()
	payload := []byte("abcdefghijklmnopqrstuvwxyz")
	if err := os.WriteFile(filepath.Join(sourceRoot, "weights.bin"), payload, 0o644); err != nil {
		t.Fatalf("WriteFile returned error: %v", err)
	}
	// Build the Castagnoli table once instead of once per chunk.
	castagnoli := crc32.MakeTable(crc32.Castagnoli)
	// buildChunk copies payload[start:end] into a chunk payload attributed to sourceID.
	buildChunk := func(start, end int64, sourceID string) sharedtypes.CollectiveChunkPayload {
		data := append([]byte(nil), payload[start:end]...)
		return sharedtypes.CollectiveChunkPayload{
			Chunk: sharedtypes.TransferredChunk{
				ChunkID:      fmt.Sprintf("%s:weights.bin:%d", sourceID, start),
				RelativePath: "weights.bin",
				Offset:       start,
				Size:         end - start,
				CRC32C:       encodeCRC32C(crc32.Checksum(data, castagnoli)),
				SourceID:     sourceID,
			},
			Data: data,
		}
	}
	// Canned list responses keyed by ring iteration.
	prevResponses := map[int32]sharedtypes.ListCollectiveChunksResponse{
		1: {TaskID: "task-b", Iteration: 1, Chunks: []sharedtypes.CollectiveChunkPayload{buildChunk(0, 9, "node-a")}},
		2: {TaskID: "task-b", Iteration: 2, Chunks: []sharedtypes.CollectiveChunkPayload{buildChunk(18, 26, "node-c")}},
	}
	prevServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/collectives/chunks/list" {
			http.NotFound(w, r)
			return
		}
		var req sharedtypes.ListCollectiveChunksRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			// Handlers run on server goroutines, where t.Fatalf must not be
			// called; record the failure and reject the request instead.
			t.Errorf("Decode returned error: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp, ok := prevResponses[req.Iteration]
		if !ok {
			// Unknown iterations get an empty chunk list echoing the request.
			resp = sharedtypes.ListCollectiveChunksResponse{TaskID: req.TaskID, Iteration: req.Iteration}
		}
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer prevServer.Close()
	var pushedMu sync.Mutex
	pushedIterations := make([]int32, 0, 2)
	nextServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/collectives/chunks/push" {
			http.NotFound(w, r)
			return
		}
		var req sharedtypes.PushCollectiveChunkRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			// See the list handler: no FailNow from a non-test goroutine.
			t.Errorf("Decode returned error: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		pushedMu.Lock()
		pushedIterations = append(pushedIterations, req.Iteration)
		pushedMu.Unlock()
		_ = json.NewEncoder(w).Encode(sharedtypes.PushCollectiveChunkResponse{
			TaskID:        req.TaskID,
			Iteration:     req.Iteration,
			Accepted:      true,
			TransportPath: sharedtypes.TransportPathTCPFallback,
		})
	}))
	defer nextServer.Close()
	adapter := NewAdapter(LocalChunkClient{}, nil)
	result, err := adapter.Execute(context.Background(), sharedtypes.TransferSpec{
		TaskID:         "task-b",
		ArtifactKey:    "artifact-1",
		TransferMode:   sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath: targetRoot,
		ChunkSizeBytes: 4,
		Parallelism:    2,
		LogicalManifest: sharedtypes.LogicalManifest{
			ArtifactKey:    "artifact-1",
			ChunkSizeBytes: 4,
			Digest:         "digest-1",
			Files: []sharedtypes.ArtifactFile{{
				RelativePath: "weights.bin",
				SizeBytes:    int64(len(payload)),
				Kind:         sharedtypes.ArtifactFileKindSafeTensors,
				Chunkable:    true,
				Required:     true,
			}},
		},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			SessionID: "collective-plan-1",
			Mode:      sharedtypes.CollectiveModeRing,
			Ring: &sharedtypes.RingPeerPlan{
				SelfNode:       "node-b",
				PrevNode:       "node-a",
				NextNode:       "node-c",
				PrevEndpoint:   prevServer.URL,
				NextEndpoint:   nextServer.URL,
				Rank:           1,
				WorldSize:      3,
				IterationCount: 2,
			},
			Peers: []sharedtypes.CollectivePeerPlan{
				{NodeName: "node-a", Rank: 0, Endpoint: prevServer.URL, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 9}}},
				{NodeName: "node-b", Rank: 1, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 9, End: 18}}},
				{NodeName: "node-c", Rank: 2, Endpoint: nextServer.URL, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 18, End: 26}}},
			},
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceID: "source-b",
			SourceEndpoint: sharedtypes.SourceEndpoint{
				SourceID:   "source-b",
				SourceType: "rdma",
				Path:       sourceRoot,
			},
			ByteRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 9, End: 18}},
		}},
	})
	if err != nil {
		t.Fatalf("Execute returned error: %v", err)
	}
	data, err := os.ReadFile(filepath.Join(targetRoot, "weights.bin"))
	if err != nil {
		t.Fatalf("ReadFile returned error: %v", err)
	}
	if !bytes.Equal(data, payload) {
		t.Fatalf("unexpected target content: %s", string(data))
	}
	if result.TransportPath != sharedtypes.TransportPathTCPFallback {
		t.Fatalf("expected transport path %q, got %q", sharedtypes.TransportPathTCPFallback, result.TransportPath)
	}
	pushedMu.Lock()
	defer pushedMu.Unlock()
	if len(pushedIterations) != 2 {
		t.Fatalf("expected 2 pushed iterations, got %d", len(pushedIterations))
	}
}
// TestBuildCollectiveOwnerFetchSpecPreservesExistingTarget checks the spec
// that buildCollectiveOwnerFetchSpec derives from a parent transfer and an
// owning peer: the task id is suffixed with the owner's node name, the mode is
// single-source direct, PreserveExisting is set, the parent's target path is
// reused, and the single source segment points at the owner's endpoint,
// staging path and owned byte range.
func TestBuildCollectiveOwnerFetchSpecPreservesExistingTarget(t *testing.T) {
	t.Parallel()
	const chunkSize = 64 << 20
	manifest := sharedtypes.LogicalManifest{
		ArtifactKey:    "artifact-1",
		ChunkSizeBytes: chunkSize,
		Digest:         "digest-1",
		Files: []sharedtypes.ArtifactFile{{
			RelativePath: "weights.bin",
			SizeBytes:    1024,
			Kind:         sharedtypes.ArtifactFileKindSafeTensors,
			Chunkable:    true,
			Required:     true,
		}},
	}
	parentSpec := sharedtypes.TransferSpec{
		TaskID:          "task-parent",
		ArtifactKey:     "artifact-1",
		TransferMode:    sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath:  "/tmp/target",
		ChunkSizeBytes:  chunkSize,
		Parallelism:     4,
		TimeoutSeconds:  3600,
		LogicalManifest: manifest,
	}
	ownerPeer := sharedtypes.CollectivePeerPlan{
		NodeName:    "node-a",
		Rank:        0,
		Endpoint:    "http://10.0.0.1:18080",
		StagingPath: "/var/lib/weight-dispatcher/cache/staging/task-parent",
		OwnedRanges: []sharedtypes.ByteRange{{
			RelativePath: "weights.bin",
			Start:        0,
			End:          512,
		}},
	}

	fetchSpec := buildCollectiveOwnerFetchSpec(parentSpec, ownerPeer)

	if got := fetchSpec.TaskID; got != "task-parent-owner-node-a" {
		t.Fatalf("unexpected task id %q", got)
	}
	if got := fetchSpec.TransferMode; got != sharedtypes.TransferModeSingleSourceDirect {
		t.Fatalf("expected single source direct mode, got %q", got)
	}
	if !fetchSpec.PreserveExisting {
		t.Fatalf("expected preserve existing to be true")
	}
	if got := fetchSpec.TargetTempPath; got != parentSpec.TargetTempPath {
		t.Fatalf("expected target temp path %q, got %q", parentSpec.TargetTempPath, got)
	}
	if count := len(fetchSpec.SourceSegments); count != 1 {
		t.Fatalf("expected 1 source segment, got %d", count)
	}
	// Exactly one segment exists past this point, so indexing is safe.
	segment := fetchSpec.SourceSegments[0]
	if got := segment.SourceEndpoint.Endpoint; got != ownerPeer.Endpoint {
		t.Fatalf("expected endpoint %q, got %q", ownerPeer.Endpoint, got)
	}
	if got := segment.SourceEndpoint.Path; got != ownerPeer.StagingPath {
		t.Fatalf("expected staging path %q, got %q", ownerPeer.StagingPath, got)
	}
	if ranges := segment.ByteRanges; len(ranges) != 1 || ranges[0].End != 512 {
		t.Fatalf("unexpected byte ranges: %+v", ranges)
	}
}
func TestAdapterExecutesRelayPeerFetchIncrementally(t *testing.T) {
t.Parallel()
rootStaging := t.TempDir()
targetRoot := t.TempDir()
payload := []byte("abcdefghijklmnopqrstuvwxyz")
rootFile := filepath.Join(rootStaging, "weights.bin")
if err := os.WriteFile(rootFile, payload, 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
var listMu sync.Mutex
listCalls := 0
rootServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/collectives/chunks/list" {
http.NotFound(w, r)
return
}
var req sharedtypes.ListCollectiveChunksRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("Decode returned error: %v", err)
}
listMu.Lock()
listCalls++
currentCall := listCalls
listMu.Unlock()
var chunks []sharedtypes.CollectiveChunkPayload
switch req.Iteration {
case 1:
chunks = []sharedtypes.CollectiveChunkPayload{{
Chunk: sharedtypes.TransferredChunk{
ChunkID: "node-root:weights.bin:0",
RelativePath: "weights.bin",
Offset: 0,
Size: 13,
SourceID: "node-root",
},
}}
case 2:
chunks = []sharedtypes.CollectiveChunkPayload{{
Chunk: sharedtypes.TransferredChunk{
ChunkID: "node-root:weights.bin:13",
RelativePath: "weights.bin",
Offset: 13,
Size: 13,
SourceID: "node-root",
},
}}
default:
chunks = nil
}
_ = currentCall
_ = json.NewEncoder(w).Encode(sharedtypes.ListCollectiveChunksResponse{
TaskID: req.TaskID,
Iteration: req.Iteration,
Chunks: chunks,
})
}))
defer rootServer.Close()
adapter := NewAdapter(LocalChunkClient{}, nil)
result, err := adapter.Execute(context.Background(), sharedtypes.TransferSpec{
TaskID: "relay-peer-task",
ArtifactKey: "artifact-1",
TransferMode: sharedtypes.TransferModePartialPullAllGather,
TargetTempPath: targetRoot,
ChunkSizeBytes: 13,
Parallelism: 1,
LogicalManifest: sharedtypes.LogicalManifest{
ArtifactKey: "artifact-1",
ChunkSizeBytes: 13,
Digest: "digest-1",
Files: []sharedtypes.ArtifactFile{{
RelativePath: "weights.bin",
SizeBytes: int64(len(payload)),
Kind: sharedtypes.ArtifactFileKindSafeTensors,
Chunkable: true,
Required: true,
}},
},
CollectiveSpec: sharedtypes.CollectiveSpec{
SessionID: "relay-session",
Mode: sharedtypes.CollectiveModeRing,
Ring: &sharedtypes.RingPeerPlan{
SelfNode: "node-peer",
PrevNode: "node-root",
NextNode: "node-root",
PrevEndpoint: rootServer.URL,
NextEndpoint: rootServer.URL,
Rank: 1,
WorldSize: 2,
},
Peers: []sharedtypes.CollectivePeerPlan{
{
NodeName: "node-root",
Endpoint: rootServer.URL,
Rank: 0,
StagingPath: rootStaging,
OwnedRanges: []sharedtypes.ByteRange{
{RelativePath: "weights.bin", Start: 0, End: 13},
{RelativePath: "weights.bin", Start: 13, End: 26},
},
},
{NodeName: "node-peer", Endpoint: "http://peer", Rank: 1},
},
},
})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
data, err := os.ReadFile(filepath.Join(targetRoot, "weights.bin"))
if err != nil {
t.Fatalf("ReadFile returned error: %v", err)
}
if string(data) != string(payload) {
t.Fatalf("unexpected target content: %s", string(data))
}
if result.BytesTransferred != int64(len(payload)) {
t.Fatalf("expected %d bytes transferred, got %d", len(payload), result.BytesTransferred)
}
}
// TestResolveSymmetricFanoutRole checks that resolveSymmetricFanoutRole
// resolves a two-node ring seen from rank 0 and reports "node-a" as the self
// peer and "node-b" as the other peer.
func TestResolveSymmetricFanoutRole(t *testing.T) {
	t.Parallel()
	peers := []sharedtypes.CollectivePeerPlan{
		{
			NodeName:    "node-a",
			Endpoint:    "http://node-a",
			Rank:        0,
			StagingPath: "/staging/a",
			OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 13}},
		},
		{
			NodeName:    "node-b",
			Endpoint:    "http://node-b",
			Rank:        1,
			StagingPath: "/staging/b",
			OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 13, End: 26}},
		},
	}
	ring := &sharedtypes.RingPeerPlan{
		SelfNode:     "node-a",
		SelfEndpoint: "http://node-a",
		Rank:         0,
		WorldSize:    2,
	}
	role, resolved := resolveSymmetricFanoutRole(sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Mode:  sharedtypes.CollectiveModeRing,
			Ring:  ring,
			Peers: peers,
		},
	})
	if !resolved {
		t.Fatalf("expected symmetric fanout role to resolve")
	}
	if name := role.selfPeer.NodeName; name != "node-a" {
		t.Fatalf("unexpected self peer %q", name)
	}
	if name := role.otherPeer.NodeName; name != "node-b" {
		t.Fatalf("unexpected other peer %q", name)
	}
}
// TestSymmetricFanoutSourceParallelism exercises
// symmetricFanoutSourceParallelism three ways: with a 2 GiB self-owned range
// (expects 4), with the self-owned range shrunk to
// 2*relayFanoutBatchTargetBytes (expects 2), and with the
// WD_RDMA_SYMMETRIC_SOURCE_PARALLELISM environment override set to "3"
// (expects 3) and "99" (expects 8). It uses t.Setenv and therefore does not
// call t.Parallel.
func TestSymmetricFanoutSourceParallelism(t *testing.T) {
	const (
		gib         = 1024 * 1024 * 1024
		overrideEnv = "WD_RDMA_SYMMETRIC_SOURCE_PARALLELISM"
	)
	selfPeer := sharedtypes.CollectivePeerPlan{
		NodeName:    "node-a",
		Rank:        0,
		Endpoint:    "http://node-a:8081",
		OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "blob4g.safetensors", Start: 0, End: 2 * gib}},
	}
	remotePeer := sharedtypes.CollectivePeerPlan{
		NodeName:    "node-b",
		Rank:        1,
		Endpoint:    "http://node-b:8081",
		OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "blob4g.safetensors", Start: 2 * gib, End: 4 * gib}},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		Parallelism:  2,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Mode: sharedtypes.CollectiveModeRing,
			Ring: &sharedtypes.RingPeerPlan{
				SelfNode:     "node-a",
				SelfEndpoint: "http://node-a:8081",
				Rank:         0,
				WorldSize:    2,
			},
			Peers: []sharedtypes.CollectivePeerPlan{selfPeer, remotePeer},
		},
	}

	// Derived from the size of the self-owned range.
	got := symmetricFanoutSourceParallelism(spec)
	if got != 4 {
		t.Fatalf("expected 2 GiB owned range to derive source parallelism 4, got %d", got)
	}
	spec.CollectiveSpec.Peers[0].OwnedRanges[0].End = 2 * relayFanoutBatchTargetBytes
	got = symmetricFanoutSourceParallelism(spec)
	if got != 2 {
		t.Fatalf("expected 1 GiB owned range to derive source parallelism 2, got %d", got)
	}

	// Environment override, evaluated against an empty spec.
	t.Setenv(overrideEnv, "3")
	got = symmetricFanoutSourceParallelism(sharedtypes.TransferSpec{})
	if got != 3 {
		t.Fatalf("expected env override 3, got %d", got)
	}
	t.Setenv(overrideEnv, "99")
	got = symmetricFanoutSourceParallelism(sharedtypes.TransferSpec{})
	if got != 8 {
		t.Fatalf("expected env override to clamp at 8, got %d", got)
	}
}
// TestResolveNPeerFanoutSpecOrdersPeersByRingStep resolves a four-node ring
// from the point of view of rank 1 ("node-b") and expects the three remote
// peers to be returned in the order node-a, node-d, node-c.
func TestResolveNPeerFanoutSpecOrdersPeersByRingStep(t *testing.T) {
	t.Parallel()
	ringPeers := []sharedtypes.CollectivePeerPlan{
		{NodeName: "node-a", Endpoint: "http://node-a:8081", Rank: 0, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 1}}},
		{NodeName: "node-b", Endpoint: "http://node-b:8081", Rank: 1, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 1, End: 2}}},
		{NodeName: "node-c", Endpoint: "http://node-c:8081", Rank: 2, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 2, End: 3}}},
		{NodeName: "node-d", Endpoint: "http://node-d:8081", Rank: 3, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 3, End: 4}}},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Mode: sharedtypes.CollectiveModeRing,
			Ring: &sharedtypes.RingPeerPlan{
				SelfNode:     "node-b",
				SelfEndpoint: "http://node-b:8081",
				Rank:         1,
				WorldSize:    4,
			},
			Peers: ringPeers,
		},
	}
	fanout, resolved := resolveNPeerFanoutSpec(spec)
	if !resolved {
		t.Fatalf("expected n-peer fanout spec to resolve")
	}
	if name := fanout.selfPeer.NodeName; name != "node-b" {
		t.Fatalf("unexpected self peer %q", name)
	}
	if count := len(fanout.otherPeers); count != 3 {
		t.Fatalf("expected 3 peer fetch targets, got %d", count)
	}
	want := []string{"node-a", "node-d", "node-c"}
	// The length check above guarantees got and want have the same size.
	got := make([]string, 0, len(want))
	for _, peer := range fanout.otherPeers {
		got = append(got, peer.NodeName)
	}
	for idx, name := range want {
		if got[idx] != name {
			t.Fatalf("unexpected ring peer order: got %v want %v", got, want)
		}
	}
}
// TestResolveNPeerFanoutSpecIncludesTwoPeerRing checks that a two-node ring
// also resolves through resolveNPeerFanoutSpec, yielding "node-a" as self and
// "node-b" as the single remote peer, and that ringOwnerStep(2, 0, 1) is 1.
func TestResolveNPeerFanoutSpecIncludesTwoPeerRing(t *testing.T) {
	t.Parallel()
	ringPeers := []sharedtypes.CollectivePeerPlan{
		{NodeName: "node-a", Endpoint: "http://node-a:8081", Rank: 0, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 8}}},
		{NodeName: "node-b", Endpoint: "http://node-b:8081", Rank: 1, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 8, End: 16}}},
	}
	fanout, resolved := resolveNPeerFanoutSpec(sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Mode: sharedtypes.CollectiveModeRing,
			Ring: &sharedtypes.RingPeerPlan{
				SelfNode:     "node-a",
				SelfEndpoint: "http://node-a:8081",
				Rank:         0,
				WorldSize:    2,
			},
			Peers: ringPeers,
		},
	})
	if !resolved {
		t.Fatalf("expected two-peer ring to resolve through n-peer path")
	}
	if name := fanout.selfPeer.NodeName; name != "node-a" {
		t.Fatalf("unexpected self peer %q", name)
	}
	if remotes := fanout.otherPeers; len(remotes) != 1 || remotes[0].NodeName != "node-b" {
		t.Fatalf("unexpected peer order: %+v", remotes)
	}
	step := ringOwnerStep(2, 0, 1)
	if step != 1 {
		t.Fatalf("expected ring step 1 for the only remote owner, got %d", step)
	}
}
// TestNPeerFanoutSourceAndPeerFetchNoopPaths walks the n-peer fanout helpers
// with inputs for which the test expects no work: buildNPeerSourceSpec must
// report that rank 1 owns the source phase and produce a single-source-direct
// spec with parallelism 4; waitForNPeerSourcePhase and executeNPeerSourcePhase
// must succeed with an empty result for a fanout spec without owned ranges;
// and executeNPeerPeerFetches must return the incoming result and transport
// path unchanged for a peer that has no owned ranges.
func TestNPeerFanoutSourceAndPeerFetchNoopPaths(t *testing.T) {
	t.Parallel()
	adapter := NewAdapterWithOptions(LocalChunkClient{}, nil, AdapterOptions{EnableFFI: false, ForceTCPFallback: true})
	targetDir := t.TempDir()
	sourceDir := t.TempDir()
	manifest := sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
		RelativePath: "model.bin",
		SizeBytes:    8,
		Chunkable:    true,
	}}}
	collective := sharedtypes.CollectiveSpec{
		SessionID: "session",
		Ring:      &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 3, SelfNode: "rank1", SelfEndpoint: "http://rank1:18080"},
		Peers: []sharedtypes.CollectivePeerPlan{
			{NodeName: "rank0", Rank: 0, Endpoint: "http://rank0:18080"},
			{NodeName: "rank1", Rank: 1, Endpoint: "http://rank1:18080", OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "model.bin", Start: 0, End: 8}}},
			{NodeName: "rank2", Rank: 2, Endpoint: "http://rank2:18080"},
		},
	}
	spec := sharedtypes.TransferSpec{
		TaskID:          "npeer-noop",
		TransferMode:    sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath:  targetDir,
		Parallelism:     4,
		LogicalManifest: manifest,
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{Path: sourceDir},
		}},
		CollectiveSpec: collective,
	}
	ctx := context.Background()

	fanout, resolved := resolveNPeerFanoutSpec(spec)
	if !resolved {
		t.Fatalf("expected n-peer spec to resolve")
	}
	selfSpec, ownsSource := adapter.buildNPeerSourceSpec(spec, fanout)
	if !ownsSource {
		t.Fatalf("expected self to own source phase")
	}
	if selfSpec.TransferMode != sharedtypes.TransferModeSingleSourceDirect || selfSpec.Parallelism != 4 {
		t.Fatalf("unexpected source spec: %#v", selfSpec)
	}

	// A fanout spec whose self peer has no owned ranges and no remote peers.
	emptyNP := nPeerFanoutSpec{selfPeer: sharedtypes.CollectivePeerPlan{Rank: 1}}
	if err := adapter.waitForNPeerSourcePhase(ctx, spec, emptyNP, false); err != nil {
		t.Fatalf("waitForNPeerSourcePhase no-op returned error: %v", err)
	}
	sourceResult, sourcePath, err := adapter.executeNPeerSourcePhase(ctx, spec, selfSpec, emptyNP, false)
	if err != nil {
		t.Fatalf("executeNPeerSourcePhase no-op returned error: %v", err)
	}
	if sourceResult.BytesTransferred != 0 || sourcePath != "" {
		t.Fatalf("expected empty source phase result, got %#v %q", sourceResult, sourcePath)
	}

	// A remote peer without owned ranges; the seed result should pass through.
	peerOnly := nPeerFanoutSpec{otherPeers: []sharedtypes.CollectivePeerPlan{{NodeName: "rank0", Rank: 0}}}
	seed := sharedtypes.TransferResult{BytesTransferred: 9}
	merged, transportPath, err := adapter.executeNPeerPeerFetches(ctx, func() {}, spec, peerOnly, seed, sharedtypes.TransportPathRDMA)
	if err != nil {
		t.Fatalf("executeNPeerPeerFetches no-op returned error: %v", err)
	}
	if merged.BytesTransferred != 9 || transportPath != sharedtypes.TransportPathRDMA {
		t.Fatalf("unexpected merged no-op result: %#v path=%q", merged, transportPath)
	}
}
// TestPublishNPeerSourceDonePushesChunkMetadataWithoutData points the ring's
// SelfEndpoint at an httptest server, calls publishNPeerSourceDone with one
// chunk descriptor, and inspects the push request the server decoded: it must
// carry the task id, session id and source node from the spec, the chunk's
// relative path and CRC32C, and no payload data.
func TestPublishNPeerSourceDonePushesChunkMetadataWithoutData(t *testing.T) {
	t.Parallel()
	var received sharedtypes.PushCollectiveChunkRequest
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Handlers run on server goroutines, where t.Fatalf must not be
		// called; record the failure and answer with an error status instead.
		if r.URL.Path != "/v1/collectives/chunks/push" {
			t.Errorf("unexpected path %s", r.URL.Path)
			http.NotFound(w, r)
			return
		}
		if err := json.NewDecoder(r.Body).Decode(&received); err != nil {
			t.Errorf("Decode returned error: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	adapter := NewAdapter(LocalChunkClient{}, nil)
	spec := sharedtypes.TransferSpec{
		TaskID:      "npeer-publish",
		ArtifactKey: "artifact",
		CollectiveSpec: sharedtypes.CollectiveSpec{
			SessionID: "session",
			Ring:      &sharedtypes.RingPeerPlan{SelfEndpoint: server.URL, SelfNode: "rank1"},
		},
	}
	chunks := []sharedtypes.TransferredChunk{{
		RelativePath: "model.bin",
		Offset:       0,
		Size:         8,
		CRC32C:       "abcd",
	}}
	if err := adapter.publishNPeerSourceDone(context.Background(), spec, sharedtypes.CollectivePeerPlan{Rank: 1}, chunks); err != nil {
		t.Fatalf("publishNPeerSourceDone returned error: %v", err)
	}
	if received.TaskID != "npeer-publish" || received.SessionID != "session" || received.SourceNode != "rank1" {
		t.Fatalf("unexpected push metadata: %#v", received)
	}
	if len(received.Data) != 0 {
		t.Fatalf("expected source-done metadata push without payload data")
	}
	if received.Chunk.RelativePath != "model.bin" || received.Chunk.CRC32C != "abcd" {
		t.Fatalf("unexpected pushed chunk: %#v", received.Chunk)
	}
}
// TestDirectoryPerFileParallelismClampsSingleSourceDirectories checks
// directoryPerFileParallelism for a spec with Parallelism 4 in four setups:
// a two-file single-source-direct transfer (expects 1), the same files in
// direct-striped mode (expects 4), a one-file single-source-direct transfer
// (expects 4), and a two-file single-source-direct transfer whose source is
// an "external" huggingface.co endpoint (expects 8).
func TestDirectoryPerFileParallelismClampsSingleSourceDirectories(t *testing.T) {
	t.Parallel()
	// twoShards returns a fresh two-file manifest listing on every call.
	twoShards := func() []sharedtypes.ArtifactFile {
		return []sharedtypes.ArtifactFile{
			{RelativePath: "model-00001.safetensors", Chunkable: true},
			{RelativePath: "model-00002.safetensors", Chunkable: true},
		}
	}
	spec := sharedtypes.TransferSpec{
		TransferMode:    sharedtypes.TransferModeSingleSourceDirect,
		Parallelism:     4,
		LogicalManifest: sharedtypes.LogicalManifest{Files: twoShards()},
	}
	firstShard := spec.LogicalManifest.Files[0]

	lanes := directoryPerFileParallelism(spec, firstShard)
	if lanes != 1 {
		t.Fatalf("expected full-dir single-source chunkable file to clamp to 1, got %d", lanes)
	}

	spec.TransferMode = sharedtypes.TransferModeDirectStriped
	lanes = directoryPerFileParallelism(spec, firstShard)
	if lanes != 4 {
		t.Fatalf("expected striped directory file to preserve parallelism, got %d", lanes)
	}

	spec.TransferMode = sharedtypes.TransferModeSingleSourceDirect
	spec.LogicalManifest.Files = []sharedtypes.ArtifactFile{{RelativePath: "single.safetensors", Chunkable: true}}
	lanes = directoryPerFileParallelism(spec, spec.LogicalManifest.Files[0])
	if lanes != 4 {
		t.Fatalf("expected single-file transfer to preserve parallelism, got %d", lanes)
	}

	spec.LogicalManifest.Files = twoShards()
	spec.SourceSegments = []sharedtypes.SourceSegmentPlan{{
		SourceEndpoint: sharedtypes.SourceEndpoint{
			SourceType: "external",
			Endpoint:   "https://huggingface.co#rev=main",
			Path:       "Qwen/Qwen3-32B",
		},
	}}
	lanes = directoryPerFileParallelism(spec, spec.LogicalManifest.Files[0])
	if lanes != 8 {
		t.Fatalf("expected huggingface directory file to widen to 8 lanes, got %d", lanes)
	}
}
// TestDirectoryPerFileFanoutPrefersPeerStagingOwnerFetch builds a
// single-file, PreserveExisting, partial-pull all-gather spec on a two-node
// ring with a "node" source and expects both isDirectoryPerFileFanoutSubTask
// and useDirectPeerStagingOwnerFetch to report true for it.
func TestDirectoryPerFileFanoutPrefersPeerStagingOwnerFetch(t *testing.T) {
	t.Parallel()
	ring := &sharedtypes.RingPeerPlan{
		SelfNode:     "n0",
		SelfEndpoint: "http://n0:18080",
		Rank:         0,
		WorldSize:    2,
	}
	source := sharedtypes.SourceSegmentPlan{
		SourceEndpoint: sharedtypes.SourceEndpoint{
			SourceType: "node",
			NodeName:   "m0",
		},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		PreserveExisting: true,
		LogicalManifest: sharedtypes.LogicalManifest{
			Files: []sharedtypes.ArtifactFile{{
				RelativePath: "model-00001-of-00017.safetensors",
				Chunkable:    true,
			}},
		},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Mode: sharedtypes.CollectiveModeRing,
			Ring: ring,
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{source},
	}
	if isSubTask := isDirectoryPerFileFanoutSubTask(spec); !isSubTask {
		t.Fatalf("expected spec to be recognized as a directory-per-file fanout sub-task")
	}
	if prefersOwnerFetch := useDirectPeerStagingOwnerFetch(spec); !prefersOwnerFetch {
		t.Fatalf("expected directory-per-file fanout to prefer peer staging owner-fetch")
	}
}
// TestBuildPerFileTransferSpecKeepsCollectiveButNPeerSourcePhaseWillDowngrade
// checks that buildPerFileTransferSpec leaves a per-file sub-task in partial
// pull all-gather mode, that the sub-task prefers direct peer-staging owner
// fetch and resolves as an n-peer fanout, and then replays in the test itself
// the source-phase rewrite (switch to single-source direct and clear the
// collective spec) to assert the resulting shape.
func TestBuildPerFileTransferSpecKeepsCollectiveButNPeerSourcePhaseWillDowngrade(t *testing.T) {
	t.Parallel()
	const (
		shard     = "model-00001-of-00017.safetensors"
		halfBytes = 1950395900
		fullBytes = 3900791800
	)
	peers := []sharedtypes.CollectivePeerPlan{
		{NodeName: "n0", Rank: 0, Endpoint: "http://n0:18080", StagingPath: "/cache/n0", OwnedRanges: []sharedtypes.ByteRange{{RelativePath: shard, Start: 0, End: halfBytes}}},
		{NodeName: "n2", Rank: 1, Endpoint: "http://n2:18080", StagingPath: "/cache/n2", OwnedRanges: []sharedtypes.ByteRange{{RelativePath: shard, Start: halfBytes, End: fullBytes}}},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		PreserveExisting: true,
		LogicalManifest: sharedtypes.LogicalManifest{
			Files: []sharedtypes.ArtifactFile{{
				RelativePath: shard,
				SizeBytes:    fullBytes,
				Chunkable:    true,
			}},
		},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Mode: sharedtypes.CollectiveModeRing,
			Ring: &sharedtypes.RingPeerPlan{
				SelfNode:     "n0",
				SelfEndpoint: "http://n0:18080",
				Rank:         0,
				WorldSize:    2,
			},
			Peers: peers,
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{
				SourceType: "node",
				NodeName:   "m0",
				Endpoint:   "http://m0:18080",
				Path:       "/src/qwen32b",
			},
			ByteRanges: []sharedtypes.ByteRange{{
				RelativePath: shard,
				Start:        0,
				End:          halfBytes,
			}},
		}},
	}

	subSpec := buildPerFileTransferSpec(spec, spec.LogicalManifest.Files[0])
	if mode := subSpec.TransferMode; mode != sharedtypes.TransferModePartialPullAllGather {
		t.Fatalf("directory-per-file fanout subtask should still be collective before source-phase rewrite, got %s", mode)
	}
	if !useDirectPeerStagingOwnerFetch(subSpec) {
		t.Fatalf("expected per-file fanout subtask to prefer direct peer staging owner-fetch")
	}
	fanout, resolved := resolveNPeerFanoutSpec(subSpec)
	if !resolved {
		t.Fatalf("expected n-peer fanout spec to resolve")
	}

	// Replay the source-phase rewrite locally on a copy of the sub-task.
	selfSpec := subSpec
	selfSpec.PreserveExisting = true
	hasOwnedRanges := countValidRanges(fanout.selfPeer.OwnedRanges) > 0
	selfOwnsSource := hasOwnedRanges || len(selfSpec.SourceSegments) > 0
	if !selfOwnsSource {
		t.Fatalf("expected self peer to own source-half")
	}
	if selfOwnsSource && useDirectPeerStagingOwnerFetch(subSpec) {
		selfSpec.TransferMode = sharedtypes.TransferModeSingleSourceDirect
		selfSpec.CollectiveSpec = sharedtypes.CollectiveSpec{}
	}
	if mode := selfSpec.TransferMode; mode != sharedtypes.TransferModeSingleSourceDirect {
		t.Fatalf("expected source-phase to downgrade to single-source direct, got %s", mode)
	}
	if cleared := selfSpec.CollectiveSpec; cleared.Ring != nil || len(cleared.Peers) != 0 {
		t.Fatalf("expected source-phase collective spec to be cleared")
	}
}
// TestWaitForPreviousDirectoryFanoutFileUsesLastRankDoneMarker gives
// waitForPreviousDirectoryFanoutFile a fake chunk client that only knows the
// done-marker path for the file under the rank-1 peer's endpoint and staging
// directory, and expects the call to succeed.
func TestWaitForPreviousDirectoryFanoutFileUsesLastRankDoneMarker(t *testing.T) {
	t.Parallel()
	const (
		lastRankEndpoint = "http://n2:18080"
		lastRankStaging  = "/var/lib/weight/cache/staging/n2"
	)
	file := sharedtypes.ArtifactFile{RelativePath: "model-00013-of-00017.safetensors", SizeBytes: 1024, Chunkable: true}
	peers := []sharedtypes.CollectivePeerPlan{
		{NodeName: "n0", Rank: 0, Endpoint: "http://n0:18080", StagingPath: "/var/lib/weight/cache/staging/n0"},
		{NodeName: "n2", Rank: 1, Endpoint: lastRankEndpoint, StagingPath: lastRankStaging},
	}
	spec := sharedtypes.TransferSpec{
		TaskID:         "fanout-file-013",
		TransferMode:   sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath: "/tmp/target",
		LogicalManifest: sharedtypes.LogicalManifest{
			Files: []sharedtypes.ArtifactFile{file},
		},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Mode:      sharedtypes.CollectiveModeRing,
			SessionID: "session-13",
			Ring: &sharedtypes.RingPeerPlan{
				Rank:      0,
				WorldSize: 2,
			},
			Peers: peers,
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{SourceType: "node", NodeName: "m0"},
			ByteRanges:     []sharedtypes.ByteRange{{RelativePath: file.RelativePath, Start: 0, End: file.SizeBytes}},
		}},
	}
	// Key layout matches fakeChunkClient.Stat: endpoint|rootPath|relativePath.
	markerKey := fmt.Sprintf("%s|%s|%s", lastRankEndpoint, lastRankStaging, fanoutStagingDoneMarkerRelativePath(file.RelativePath))
	client := fakeChunkClient{stats: map[string]int64{markerKey: 8}}
	err := waitForPreviousDirectoryFanoutFile(context.Background(), client, spec, file)
	if err != nil {
		t.Fatalf("expected previous-file barrier to succeed, got %v", err)
	}
}
// TestFanoutMarkerHelpersWriteAndWait calls markFanoutSourceHalfReady and
// markFanoutFileDone for a spec whose target and peer staging paths are the
// same temp directory, then expects waitForPeerFanoutSourceHalfReady and
// waitForPeerFanoutFileDone to return without error when pointed at that
// directory through a LocalChunkClient.
func TestFanoutMarkerHelpersWriteAndWait(t *testing.T) {
	t.Parallel()
	stagingDir := t.TempDir()
	const relativePath = "nested/model.safetensors"
	wholeFile := []sharedtypes.ByteRange{{RelativePath: relativePath, Start: 0, End: 128}}
	spec := sharedtypes.TransferSpec{
		TaskID:         "fanout-file-001",
		TransferMode:   sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath: stagingDir,
		LogicalManifest: sharedtypes.LogicalManifest{
			Files: []sharedtypes.ArtifactFile{{RelativePath: relativePath, SizeBytes: 128, Chunkable: true}},
		},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 2},
			Peers: []sharedtypes.CollectivePeerPlan{
				{NodeName: "n0", Rank: 0, Endpoint: "local", StagingPath: stagingDir},
				{NodeName: "n2", Rank: 1, Endpoint: "local", StagingPath: stagingDir},
			},
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			ByteRanges: []sharedtypes.ByteRange{{RelativePath: relativePath, Start: 0, End: 128}},
		}},
	}

	// Write both markers first.
	if err := markFanoutSourceHalfReady(spec); err != nil {
		t.Fatalf("markFanoutSourceHalfReady returned error: %v", err)
	}
	if err := markFanoutFileDone(spec); err != nil {
		t.Fatalf("markFanoutFileDone returned error: %v", err)
	}

	// Then wait on them through a peer that shares the staging directory.
	ctx := context.Background()
	client := LocalChunkClient{}
	peer := sharedtypes.CollectivePeerPlan{NodeName: "n0", Endpoint: "local", StagingPath: stagingDir}
	if err := waitForPeerFanoutSourceHalfReady(ctx, client, spec, peer, wholeFile); err != nil {
		t.Fatalf("waitForPeerFanoutSourceHalfReady returned error: %v", err)
	}
	if err := waitForPeerFanoutFileDone(ctx, client, spec, peer, relativePath); err != nil {
		t.Fatalf("waitForPeerFanoutFileDone returned error: %v", err)
	}
}
// TestWaitForNPeerSourceTurnWaitsOnPredecessorMarker writes the
// source-half-ready marker into a shared temp staging directory via
// markFanoutSourceHalfReady and then expects waitForNPeerSourceTurn to return
// without error for rank 1, whose listed peers all use that directory.
func TestWaitForNPeerSourceTurnWaitsOnPredecessorMarker(t *testing.T) {
	t.Parallel()
	stagingDir := t.TempDir()
	const relativePath = "model.bin"
	stagedPeers := func() []sharedtypes.CollectivePeerPlan {
		return []sharedtypes.CollectivePeerPlan{
			{NodeName: "rank0", Rank: 0, Endpoint: "local", StagingPath: stagingDir},
			{NodeName: "rank1", Rank: 1, Endpoint: "local", StagingPath: stagingDir},
		}
	}
	spec := sharedtypes.TransferSpec{
		TaskID:         "npeer-file-002",
		TransferMode:   sharedtypes.TransferModePartialPullAllGather,
		TargetTempPath: stagingDir,
		LogicalManifest: sharedtypes.LogicalManifest{
			Files: []sharedtypes.ArtifactFile{{RelativePath: relativePath, SizeBytes: 64, Chunkable: true}},
		},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring:  &sharedtypes.RingPeerPlan{Rank: 1, WorldSize: 3},
			Peers: stagedPeers(),
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			ByteRanges: []sharedtypes.ByteRange{{RelativePath: relativePath, Start: 0, End: 64}},
		}},
	}
	if err := markFanoutSourceHalfReady(spec); err != nil {
		t.Fatalf("markFanoutSourceHalfReady returned error: %v", err)
	}
	fanout := nPeerFanoutSpec{
		selfPeer: sharedtypes.CollectivePeerPlan{
			NodeName: "rank1",
			Rank:     1,
			OwnedRanges: []sharedtypes.ByteRange{{
				RelativePath: relativePath,
				Start:        0,
				End:          64,
			}},
		},
		otherPeers: stagedPeers(),
	}
	err := waitForNPeerSourceTurn(context.Background(), LocalChunkClient{}, spec, fanout)
	if err != nil {
		t.Fatalf("waitForNPeerSourceTurn returned error: %v", err)
	}
}
// TestTCPDirectoryFanoutSourceTurnDefaultsToSerial checks
// shouldSerializeNPeerSourceTurn with WD_TCP_FANOUT_SERIAL_SOURCE_TURN empty
// (expects true for both values of the boolean argument) and with it set to
// "false" (expects false when the boolean argument is true). It uses t.Setenv
// and therefore does not call t.Parallel.
func TestTCPDirectoryFanoutSourceTurnDefaultsToSerial(t *testing.T) {
	const serialEnv = "WD_TCP_FANOUT_SERIAL_SOURCE_TURN"
	t.Setenv(serialEnv, "")
	manifest := sharedtypes.LogicalManifest{Files: []sharedtypes.ArtifactFile{{
		RelativePath: "model.bin",
		Chunkable:    true,
	}}}
	spec := sharedtypes.TransferSpec{
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		PreserveExisting: true,
		LogicalManifest:  manifest,
	}
	if serial := shouldSerializeNPeerSourceTurn(spec, true); !serial {
		t.Fatalf("TCP directory fanout should serialize source turns by default")
	}
	if serial := shouldSerializeNPeerSourceTurn(spec, false); !serial {
		t.Fatalf("RDMA directory fanout should keep serialized source turns")
	}
	t.Setenv(serialEnv, "false")
	if serial := shouldSerializeNPeerSourceTurn(spec, true); serial {
		t.Fatalf("TCP directory fanout should allow env override to concurrent source turns")
	}
}
// TestWaitForPeerRequiredFileSizes writes a 5-byte "model.bin" into a temp
// staging directory and expects waitForPeerRequiredFileSizes to succeed when
// asked to wait for that file at that size on a peer using the directory.
func TestWaitForPeerRequiredFileSizes(t *testing.T) {
	t.Parallel()
	stagingDir := t.TempDir()
	content := []byte("ready")
	modelPath := filepath.Join(stagingDir, "model.bin")
	if err := os.WriteFile(modelPath, content, 0o644); err != nil {
		t.Fatalf("WriteFile returned error: %v", err)
	}
	adapter := NewAdapter(LocalChunkClient{}, nil)
	peer := sharedtypes.CollectivePeerPlan{NodeName: "n0", Endpoint: "local", StagingPath: stagingDir}
	// The required size equals the number of bytes written above (5).
	required := map[string]int64{"model.bin": int64(len(content))}
	err := adapter.waitForPeerRequiredFileSizes(context.Background(), peer, required, time.Second, "file")
	if err != nil {
		t.Fatalf("waitForPeerRequiredFileSizes returned error: %v", err)
	}
}
// TestDirectoryFanoutWindowDefaultsToTwoForCollectiveDirectories expects
// directoryFanoutWindow to return 2 for a three-file partial-pull all-gather
// transfer on a two-node ring.
func TestDirectoryFanoutWindowDefaultsToTwoForCollectiveDirectories(t *testing.T) {
	t.Parallel()
	files := []sharedtypes.ArtifactFile{
		{RelativePath: "a", Chunkable: true},
		{RelativePath: "b", Chunkable: true},
		{RelativePath: "c", Chunkable: true},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 2},
		},
	}
	window := directoryFanoutWindow(spec, files)
	if window != 2 {
		t.Fatalf("expected default fanout window 2, got %d", window)
	}
}
// TestDirectoryFanoutWindowHonorsEnvOverride sets WD_DIRECTORY_FANOUT_WINDOW
// to "3" and expects directoryFanoutWindow to return 3 for a four-file
// partial-pull all-gather transfer on a three-node ring. It uses t.Setenv and
// therefore does not call t.Parallel.
func TestDirectoryFanoutWindowHonorsEnvOverride(t *testing.T) {
	t.Setenv("WD_DIRECTORY_FANOUT_WINDOW", "3")
	files := []sharedtypes.ArtifactFile{
		{RelativePath: "a", Chunkable: true},
		{RelativePath: "b", Chunkable: true},
		{RelativePath: "c", Chunkable: true},
		{RelativePath: "d", Chunkable: true},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 3},
		},
	}
	window := directoryFanoutWindow(spec, files)
	if window != 3 {
		t.Fatalf("expected overridden fanout window 3, got %d", window)
	}
}
// TestShouldPipelineDirectoryFanoutOnlyForTCPFallback checks
// shouldPipelineDirectoryFanout: true for a two-file all-gather ring transfer
// when the boolean argument is true, and false when the boolean argument is
// false, when only one file is passed, or when the transfer mode is direct
// striped.
func TestShouldPipelineDirectoryFanoutOnlyForTCPFallback(t *testing.T) {
	twoFiles := []sharedtypes.ArtifactFile{
		{RelativePath: "a.bin", Chunkable: true},
		{RelativePath: "b.bin", Chunkable: true},
	}
	fanoutSpec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 2},
		},
	}
	stripedSpec := sharedtypes.TransferSpec{TransferMode: sharedtypes.TransferModeDirectStriped}

	if pipelined := shouldPipelineDirectoryFanout(fanoutSpec, twoFiles, true); !pipelined {
		t.Fatalf("expected TCP fallback directory fanout to use pipeline")
	}
	if pipelined := shouldPipelineDirectoryFanout(fanoutSpec, twoFiles, false); pipelined {
		t.Fatalf("RDMA directory fanout should keep the stable non-pipelined path")
	}
	if pipelined := shouldPipelineDirectoryFanout(fanoutSpec, twoFiles[:1], true); pipelined {
		t.Fatalf("single-file fanout should not use directory pipeline")
	}
	if pipelined := shouldPipelineDirectoryFanout(stripedSpec, twoFiles, true); pipelined {
		t.Fatalf("non-fanout directory transfer should not use fanout pipeline")
	}
}
// TestTCPDirectoryFanoutPipelineDepthDefaultsAndEnv checks the values returned
// by tcpDirectoryFanoutSourceAhead and tcpDirectoryFanoutExchangeWorkers with
// no override, with large overrides, and with invalid or zero overrides.
func TestTCPDirectoryFanoutPipelineDepthDefaultsAndEnv(t *testing.T) {
	names := []string{"a.bin", "b.bin", "c.bin", "d.bin"}
	few := make([]sharedtypes.ArtifactFile, 0, len(names))
	for _, name := range names {
		few = append(few, sharedtypes.ArtifactFile{RelativePath: name, Chunkable: true})
	}

	const manyCount = 12
	many := make([]sharedtypes.ArtifactFile, 0, manyCount)
	for i := 0; i < manyCount; i++ {
		many = append(many, sharedtypes.ArtifactFile{
			RelativePath: fmt.Sprintf("file-%02d.bin", i),
			Chunkable:    true,
		})
	}
	fewCount := len(few)

	// Phase 1: no override set by this test.
	// NOTE(review): assumes the WD_TCP_FANOUT_* variables are not set in the
	// ambient environment.
	if ahead := tcpDirectoryFanoutSourceAhead(few); ahead != fewCount {
		t.Fatalf("expected default TCP source-ahead to clamp to file count, got %d", ahead)
	}
	if workers := tcpDirectoryFanoutExchangeWorkers(few); workers != fewCount {
		t.Fatalf("expected default TCP exchange workers to clamp to file count, got %d", workers)
	}
	if ahead := tcpDirectoryFanoutSourceAhead(many); ahead != defaultTCPFanoutSourceAhead {
		t.Fatalf("expected default TCP source-ahead %d, got %d", defaultTCPFanoutSourceAhead, ahead)
	}
	if workers := tcpDirectoryFanoutExchangeWorkers(many); workers != defaultTCPFanoutExchangeWorkers {
		t.Fatalf("expected default TCP exchange workers %d, got %d", defaultTCPFanoutExchangeWorkers, workers)
	}

	// Phase 2: overrides larger than the four-file directory.
	t.Setenv("WD_TCP_FANOUT_SOURCE_AHEAD", "8")
	t.Setenv("WD_TCP_FANOUT_EXCHANGE_WORKERS", "9")
	if ahead := tcpDirectoryFanoutSourceAhead(few); ahead != fewCount {
		t.Fatalf("source-ahead should clamp to file count, got %d", ahead)
	}
	if workers := tcpDirectoryFanoutExchangeWorkers(few); workers != fewCount {
		t.Fatalf("exchange workers should clamp to file count, got %d", workers)
	}

	// Phase 3: unparsable and zero overrides.
	t.Setenv("WD_TCP_FANOUT_SOURCE_AHEAD", "bad")
	t.Setenv("WD_TCP_FANOUT_EXCHANGE_WORKERS", "0")
	if ahead := tcpDirectoryFanoutSourceAhead(few[:1]); ahead != 1 {
		t.Fatalf("single-file source-ahead should clamp to 1, got %d", ahead)
	}
	if workers := tcpDirectoryFanoutExchangeWorkers(few); workers != 1 {
		t.Fatalf("invalid worker env should clamp to 1, got %d", workers)
	}
}
// TestPerFileHelperAdditionalBranches covers further branches of the per-file
// directory helpers: shouldExecuteDirectoryPerFile, directoryFanoutWindow,
// maxConfiguredParallelism, maxParallelism, buildFullFileSourceSegments and
// shouldBypassDirectPullFFI.
func TestPerFileHelperAdditionalBranches(t *testing.T) {
	pair := []sharedtypes.ArtifactFile{
		{RelativePath: "a.bin", Chunkable: true},
		{RelativePath: "b.bin", Chunkable: true},
	}
	single := pair[:1]
	stripedSpec := sharedtypes.TransferSpec{TransferMode: sharedtypes.TransferModeDirectStriped}

	// --- shouldExecuteDirectoryPerFile ---
	singleSourceSpec := sharedtypes.TransferSpec{TransferMode: sharedtypes.TransferModeSingleSourceDirect}
	if shouldExecuteDirectoryPerFile(singleSourceSpec, single) {
		t.Fatalf("single file should not execute directory-per-file path")
	}
	emptyModeSpec := sharedtypes.TransferSpec{TransferMode: sharedtypes.TransferMode("")}
	if shouldExecuteDirectoryPerFile(emptyModeSpec, pair) {
		t.Fatalf("unsupported transfer mode should not execute directory-per-file path")
	}
	if perFile := shouldExecuteDirectoryPerFile(stripedSpec, pair); !perFile {
		t.Fatalf("striped directory should execute directory-per-file path")
	}

	// --- directoryFanoutWindow ---
	gatherSpec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 2},
		},
	}
	if window := directoryFanoutWindow(gatherSpec, single); window != 1 {
		t.Fatalf("single file fanout window should be 1, got %d", window)
	}
	if window := directoryFanoutWindow(stripedSpec, pair); window != 1 {
		t.Fatalf("non-fanout directory window should be 1, got %d", window)
	}
	t.Setenv("WD_DIRECTORY_FANOUT_WINDOW", "bad")
	if window := directoryFanoutWindow(gatherSpec, pair); window != 2 {
		t.Fatalf("invalid env should use default fanout window, got %d", window)
	}
	t.Setenv("WD_DIRECTORY_FANOUT_WINDOW", "0")
	if window := directoryFanoutWindow(gatherSpec, pair); window != 1 {
		t.Fatalf("zero env should clamp to 1, got %d", window)
	}
	t.Setenv("WD_DIRECTORY_FANOUT_WINDOW", "9")
	if window := directoryFanoutWindow(gatherSpec, pair); window != len(pair) {
		t.Fatalf("large env should clamp to file count, got %d", window)
	}

	// --- maxConfiguredParallelism / maxParallelism ---
	// The first lookup happens before this test sets the variable.
	const envKey = "WD_TEST_MAX_PARALLELISM"
	value, ok := maxConfiguredParallelism(envKey, 2)
	if ok || value != 0 {
		t.Fatalf("empty max parallelism env should be ignored, value=%d ok=%v", value, ok)
	}
	t.Setenv(envKey, "bad")
	value, ok = maxConfiguredParallelism(envKey, 2)
	if ok || value != 0 {
		t.Fatalf("invalid max parallelism env should be ignored, value=%d ok=%v", value, ok)
	}
	t.Setenv(envKey, "6")
	value, ok = maxConfiguredParallelism(envKey, 2)
	if !ok || value != 6 {
		t.Fatalf("expected configured max parallelism 6, value=%d ok=%v", value, ok)
	}
	if !(maxParallelism(7, 3) == 7 && maxParallelism(3, 7) == 7) {
		t.Fatalf("maxParallelism should return larger value")
	}

	// --- buildFullFileSourceSegments ---
	if none := buildFullFileSourceSegments(nil, "a.bin"); none != nil {
		t.Fatalf("empty full-file segments should stay nil, got %#v", none)
	}
	rangedSegment := sharedtypes.SourceSegmentPlan{
		SourceID: "source-a",
		ByteRanges: []sharedtypes.ByteRange{
			{RelativePath: "a.bin", Start: 0, End: 10},
		},
	}
	whole := buildFullFileSourceSegments([]sharedtypes.SourceSegmentPlan{rangedSegment}, "a.bin")
	switch {
	case len(whole) != 1, len(whole[0].ByteRanges) != 0:
		t.Fatalf("expected full-file source segment without ranges, got %#v", whole)
	}

	// --- shouldBypassDirectPullFFI ---
	huggingFaceSpec := sharedtypes.TransferSpec{
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{
				SourceType: "huggingface",
				Endpoint:   "https://huggingface.co/Qwen/Qwen3-8B",
			},
		}},
	}
	if bypass := shouldBypassDirectPullFFI(huggingFaceSpec, single); !bypass {
		t.Fatalf("huggingface source should bypass native direct pull")
	}
	// Source node name equals the ring's SelfNode ("m0").
	localCopySpec := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModeSingleSourceDirect,
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{SelfNode: "m0"},
		},
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{SourceType: "node", NodeName: "m0"},
		}},
	}
	if bypass := shouldBypassDirectPullFFI(localCopySpec, single); !bypass {
		t.Fatalf("local self-copy fanout leg should bypass native direct pull")
	}
	if bypass := shouldBypassDirectPullFFI(stripedSpec, single); bypass {
		t.Fatalf("plain striped single-file transfer should keep native direct pull eligible")
	}
}
// TestJobBuildHelperFallbackBranches exercises buildTransferJobs and
// mustStatLocal: a single-source segment backed by a local file, a manifest
// with a zero-size entry, explicit byte ranges split by chunk size, and the
// two error paths asserted at the end.
func TestJobBuildHelperFallbackBranches(t *testing.T) {
	t.Parallel()

	dir := t.TempDir()
	localFile := filepath.Join(dir, "single.bin")
	if writeErr := os.WriteFile(localFile, []byte("abcdef"), 0o644); writeErr != nil {
		t.Fatalf("WriteFile returned error: %v", writeErr)
	}
	localSegment := sharedtypes.SourceSegmentPlan{
		SourceID:       "source-a",
		SourceEndpoint: sharedtypes.SourceEndpoint{Path: localFile},
	}

	// No manifest and no ranges: the job is expected to describe the 6-byte file.
	statJobs, statErr := buildTransferJobs(
		sharedtypes.TransferModeSingleSourceDirect,
		[]sharedtypes.SourceSegmentPlan{localSegment},
		nil,
		0,
	)
	if statErr != nil {
		t.Fatalf("buildTransferJobs returned error: %v", statErr)
	}
	switch {
	case len(statJobs) != 1,
		statJobs[0].relativePath != "single.bin",
		statJobs[0].length != 6:
		t.Fatalf("unexpected single source job: %#v", statJobs)
	}

	missing := filepath.Join(dir, "missing.bin")
	if size := mustStatLocal(missing); size != 0 {
		t.Fatalf("missing local stat should be 0, got %d", size)
	}

	// Manifest-driven jobs: only the non-empty file is expected to yield a job.
	manifest := []sharedtypes.ArtifactFile{
		{RelativePath: "skip.bin", SizeBytes: 0},
		{RelativePath: "model.bin", SizeBytes: 5},
	}
	manifestJobs, manifestErr := buildTransferJobs(
		sharedtypes.TransferModeSingleSourceDirect,
		[]sharedtypes.SourceSegmentPlan{localSegment},
		manifest,
		8,
	)
	if manifestErr != nil {
		t.Fatalf("buildTransferJobs full file returned error: %v", manifestErr)
	}
	switch {
	case len(manifestJobs) != 1,
		manifestJobs[0].relativePath != "model.bin",
		manifestJobs[0].length != 5:
		t.Fatalf("unexpected full file jobs: %#v", manifestJobs)
	}

	// Explicit range [2, 11) with chunk size 4 is expected to split into 4+4+1.
	rangedSegment := sharedtypes.SourceSegmentPlan{
		SourceID: "peer-a",
		ByteRanges: []sharedtypes.ByteRange{
			{RelativePath: "model.bin", Start: 2, End: 11},
		},
	}
	chunked, chunkErr := buildTransferJobs(
		sharedtypes.TransferModeSingleSourceDirect,
		[]sharedtypes.SourceSegmentPlan{rangedSegment},
		nil,
		4,
	)
	if chunkErr != nil {
		t.Fatalf("buildTransferJobs range returned error: %v", chunkErr)
	}
	// Cases are evaluated left to right, so the length check guards the indexing.
	switch {
	case len(chunked) != 3,
		chunked[0].offset != 2, chunked[0].length != 4,
		chunked[1].offset != 6, chunked[1].length != 4,
		chunked[2].offset != 10, chunked[2].length != 1:
		t.Fatalf("explicit single-source ranges should be chunked, got %#v", chunked)
	}

	_, stripedErr := buildTransferJobs(
		sharedtypes.TransferModeDirectStriped,
		[]sharedtypes.SourceSegmentPlan{localSegment},
		nil,
		0,
	)
	if stripedErr == nil || !strings.Contains(stripedErr.Error(), "requires explicit byte ranges") {
		t.Fatalf("expected explicit range error, got %v", stripedErr)
	}

	_, modeErr := buildTransferJobs(sharedtypes.TransferMode("unknown"), nil, nil, 0)
	if modeErr == nil || !strings.Contains(modeErr.Error(), "unsupported transfer mode") {
		t.Fatalf("expected unsupported mode error, got %v", modeErr)
	}
}
// TestNewHTTPChunkClientUsesTunedTransport expects NewHTTPChunkClient(nil) to
// install an *http.Transport with at least 256 idle connections overall and
// at least 64 per host.
func TestNewHTTPChunkClientUsesTunedTransport(t *testing.T) {
	t.Parallel()

	roundTripper := NewHTTPChunkClient(nil).client.Transport
	tuned, isHTTPTransport := roundTripper.(*http.Transport)
	if !isHTTPTransport {
		t.Fatalf("expected tuned HTTP transport, got %T", roundTripper)
	}

	poolLargeEnough := tuned.MaxIdleConns >= 256 && tuned.MaxIdleConnsPerHost >= 64
	if !poolLargeEnough {
		t.Fatalf("expected enlarged data-plane connection pool, got maxIdle=%d perHost=%d",
			tuned.MaxIdleConns, tuned.MaxIdleConnsPerHost)
	}
}
// TestFanoutSizingAndPathHelpers checks the per-file sizes derived by
// requiredRelayFileSizes and requiredRangeFileSizes, the staging marker path
// helpers, and fanoutSubTaskRelativePath for manifest- and range-based specs.
func TestFanoutSizingAndPathHelpers(t *testing.T) {
	t.Parallel()

	// Chunk payloads: a.bin ends at 10+5, b.bin ends at 3+4.
	chunkPayloads := []sharedtypes.CollectiveChunkPayload{
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "a.bin", Offset: 0, Size: 10}},
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "a.bin", Offset: 10, Size: 5}},
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "b.bin", Offset: 3, Size: 4}},
	}
	byRelay := requiredRelayFileSizes(chunkPayloads)
	if !(byRelay["a.bin"] == 15 && byRelay["b.bin"] == 7) {
		t.Fatalf("unexpected relay sizes: %#v", byRelay)
	}

	// Only a.bin is asserted; the zero-length ignored.bin range is not checked.
	byteRanges := []sharedtypes.ByteRange{
		{RelativePath: "a.bin", Start: 0, End: 8},
		{RelativePath: "a.bin", Start: 8, End: 12},
		{RelativePath: "ignored.bin", Start: 5, End: 5},
	}
	byRange := requiredRangeFileSizes(byteRanges)
	if size := byRange["a.bin"]; size != 12 {
		t.Fatalf("unexpected range sizes: %#v", byRange)
	}

	const stagedFile = "dir/file.bin"
	readyMarker := fanoutStagingReadyMarkerRelativePath(stagedFile)
	if readyMarker != "dir/file.bin.wdfanout.ready" {
		t.Fatalf("unexpected ready marker path %q", readyMarker)
	}
	doneMarker := fanoutStagingDoneMarkerRelativePath(stagedFile)
	if doneMarker != "dir/file.bin.wdfanout.done" {
		t.Fatalf("unexpected done marker path %q", doneMarker)
	}

	manifestSpec := sharedtypes.TransferSpec{
		LogicalManifest: sharedtypes.LogicalManifest{
			Files: []sharedtypes.ArtifactFile{{RelativePath: "manifest.bin"}},
		},
	}
	if rel := fanoutSubTaskRelativePath(manifestSpec); rel != "manifest.bin" {
		t.Fatalf("unexpected manifest relative path %q", rel)
	}

	rangeSpec := sharedtypes.TransferSpec{
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			ByteRanges: []sharedtypes.ByteRange{{RelativePath: "range.bin"}},
		}},
	}
	if rel := fanoutSubTaskRelativePath(rangeSpec); rel != "range.bin" {
		t.Fatalf("unexpected range relative path %q", rel)
	}
}
func TestSymmetricFanoutParallelismHelpers(t *testing.T) {
t.Setenv("WD_RDMA_SYMMETRIC_SOURCE_PARALLELISM", "")
t.Setenv("WD_RDMA_SYMMETRIC_OWNER_FETCH_AFTER_INITIAL", "")
role := symmetricFanoutRole{selfPeer: sharedtypes.CollectivePeerPlan{
OwnedRanges: []sharedtypes.ByteRange{{Start: 0, End: relayFanoutBatchTargetBytes + 1}},
}}
if got, ok := derivedSymmetricFanoutSourceParallelism(role); !ok || got != 2 {
t.Fatalf("expected two derived batches, got %d ok=%v", got, ok)
}
if got, ok := derivedSymmetricFanoutSourceParallelism(symmetricFanoutRole{}); ok || got != 0 {
t.Fatalf("expected empty ranges to be unsupported, got %d ok=%v", got, ok)
}
if got := boundedSymmetricFanoutSourceParallelism(0); got != defaultSymmetricFanoutSourceParallelism {
t.Fatalf("expected default parallelism, got %d", got)
}
if got := boundedSymmetricFanoutSourceParallelism(maxSymmetricFanoutSourceParallelism + 10); got != maxSymmetricFanoutSourceParallelism {
t.Fatalf("expected max parallelism clamp, got %d", got)
}
if !symmetricOwnerFetchAfterInitial() {
t.Fatalf("expected owner fetch after initial by default")
}
t.Setenv("WD_RDMA_SYMMETRIC_OWNER_FETCH_AFTER_INITIAL", "false")
if symmetricOwnerFetchAfterInitial() {
t.Fatalf("expected owner fetch env override to disable default")
}
t.Setenv("WD_RDMA_SYMMETRIC_SOURCE_PARALLELISM", "99")
if got, ok := configuredSymmetricFanoutSourceParallelism(); !ok || got != 8 {
t.Fatalf("expected configured parallelism clamp to 8, got %d ok=%v", got, ok)
}
t.Setenv("WD_RDMA_SYMMETRIC_SOURCE_PARALLELISM", "bad")
if got, ok := configuredSymmetricFanoutSourceParallelism(); ok || got != 0 {
t.Fatalf("expected invalid configured parallelism to be ignored, got %d ok=%v", got, ok)
}
}
// TestBuildWholeFileFanoutSubSpecAssignsRankZeroOwnership checks the sub-spec
// that buildWholeFileFanoutSubSpec derives from a two-peer parent spec for a
// 128-byte file. It expects parallelism 1, rank 0 owning a single range ending
// at 128, rank 1 owning nothing, and a single range-less source segment when
// the local rank is 0. With the local rank switched to 1 it expects nil source
// segments.
func TestBuildWholeFileFanoutSubSpecAssignsRankZeroOwnership(t *testing.T) {
	t.Parallel()
	parent := sharedtypes.TransferSpec{
		TransferMode: sharedtypes.TransferModePartialPullAllGather,
		SourceSegments: []sharedtypes.SourceSegmentPlan{{
			SourceEndpoint: sharedtypes.SourceEndpoint{Path: "/source"},
			ByteRanges:     []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 64}},
		}},
		CollectiveSpec: sharedtypes.CollectiveSpec{
			Ring: &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 2},
			Peers: []sharedtypes.CollectivePeerPlan{
				{Rank: 0, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 0, End: 32}}},
				{Rank: 1, OwnedRanges: []sharedtypes.ByteRange{{RelativePath: "weights.bin", Start: 32, End: 64}}},
			},
		},
	}
	file := sharedtypes.ArtifactFile{RelativePath: "weights.bin", SizeBytes: 128}

	subSpec := buildWholeFileFanoutSubSpec(parent, file)
	if subSpec.Parallelism != 1 {
		t.Fatalf("expected whole-file fanout parallelism 1, got %d", subSpec.Parallelism)
	}
	// Guard the indexing below so a sub-spec with missing peers fails with a
	// readable message instead of an index-out-of-range panic.
	peers := subSpec.CollectiveSpec.Peers
	if len(peers) < 2 {
		t.Fatalf("expected sub-spec to keep both peers, got %#v", peers)
	}
	if got := peers[0].OwnedRanges; len(got) != 1 || got[0].End != 128 {
		t.Fatalf("expected rank0 to own whole file, got %#v", got)
	}
	if got := peers[1].OwnedRanges; len(got) != 0 {
		t.Fatalf("expected non-root peer ownership to be cleared, got %#v", got)
	}
	if len(subSpec.SourceSegments) != 1 || subSpec.SourceSegments[0].ByteRanges != nil {
		t.Fatalf("expected rank0 whole-file source segment, got %#v", subSpec.SourceSegments)
	}

	// Ring is a pointer, so this switches the local rank seen by the next call.
	parent.CollectiveSpec.Ring.Rank = 1
	subSpec = buildWholeFileFanoutSubSpec(parent, file)
	if subSpec.SourceSegments != nil {
		t.Fatalf("expected non-root rank source segments to be empty, got %#v", subSpec.SourceSegments)
	}
}
// TestDirectoryPerFileParallelismHelpers checks the values returned by
// directoryPerFileParallelismWithForce for single-source, striped, fanout and
// Hugging Face specs, before and after the matching env override is set.
func TestDirectoryPerFileParallelismHelpers(t *testing.T) {
	const (
		stripedEnv = "WD_TCP_DIRECTORY_PER_FILE_STRIPED_PARALLELISM"
		fanoutEnv  = "WD_TCP_DIRECTORY_PER_FILE_FANOUT_PARALLELISM"
		hfEnv      = "WD_HF_DIRECTORY_PER_FILE_PARALLELISM"
		singleEnv  = "WD_RDMA_DIRECTORY_PER_FILE_SINGLE_PARALLELISM"
	)
	// Start with every override empty.
	t.Setenv(stripedEnv, "")
	t.Setenv(fanoutEnv, "")
	t.Setenv(hfEnv, "")
	t.Setenv(singleEnv, "")

	target := sharedtypes.ArtifactFile{RelativePath: "weights.bin", Chunkable: true}
	manifest := sharedtypes.LogicalManifest{
		Files: []sharedtypes.ArtifactFile{
			{RelativePath: "a.bin", Chunkable: true},
			{RelativePath: "b.bin", Chunkable: true},
		},
	}
	nodeSource := sharedtypes.SourceSegmentPlan{
		SourceEndpoint: sharedtypes.SourceEndpoint{SourceType: "node", Endpoint: "n0"},
	}
	base := sharedtypes.TransferSpec{
		TransferMode:    sharedtypes.TransferModeSingleSourceDirect,
		Parallelism:     16,
		LogicalManifest: manifest,
		SourceSegments:  []sharedtypes.SourceSegmentPlan{nodeSource},
	}

	// Single-source directory pull, without and with the force flag.
	if p := directoryPerFileParallelismWithForce(base, target, false); p != 1 {
		t.Fatalf("expected RDMA local directory single pull clamp to 1, got %d", p)
	}
	if p := directoryPerFileParallelismWithForce(base, target, true); p != 16 {
		t.Fatalf("expected TCP fallback to keep caller parallelism, got %d", p)
	}
	t.Setenv(singleEnv, "4")
	if p := directoryPerFileParallelismWithForce(base, target, false); p != 4 {
		t.Fatalf("expected env single directory clamp, got %d", p)
	}

	// Striped mode with the force flag.
	stripedSpec := base
	stripedSpec.TransferMode = sharedtypes.TransferModeDirectStriped
	stripedSpec.Parallelism = 1
	if p := directoryPerFileParallelismWithForce(stripedSpec, target, true); p != minTCPStripedFileParallelism {
		t.Fatalf("expected TCP striped minimum, got %d", p)
	}
	t.Setenv(stripedEnv, "7")
	if p := directoryPerFileParallelismWithForce(stripedSpec, target, true); p != 7 {
		t.Fatalf("expected TCP striped env max, got %d", p)
	}

	// All-gather fanout mode with the force flag.
	fanoutSpec := base
	fanoutSpec.TransferMode = sharedtypes.TransferModePartialPullAllGather
	fanoutSpec.Parallelism = 1
	if p := directoryPerFileParallelismWithForce(fanoutSpec, target, true); p != minTCPFanoutFileParallelism {
		t.Fatalf("expected TCP fanout minimum, got %d", p)
	}
	t.Setenv(fanoutEnv, "6")
	if p := directoryPerFileParallelismWithForce(fanoutSpec, target, true); p != 6 {
		t.Fatalf("expected TCP fanout env max, got %d", p)
	}

	// Hugging Face source without the force flag.
	hfSpec := base
	hfSpec.Parallelism = 2
	hfSpec.SourceSegments = []sharedtypes.SourceSegmentPlan{{
		SourceEndpoint: sharedtypes.SourceEndpoint{
			SourceType: "huggingface",
			Endpoint:   "https://huggingface.co/Qwen/Qwen3-8B",
		},
	}}
	if p := directoryPerFileParallelismWithForce(hfSpec, target, false); p != minHuggingFaceFileParallelism {
		t.Fatalf("expected HF minimum parallelism, got %d", p)
	}
	t.Setenv(hfEnv, "10")
	if p := directoryPerFileParallelismWithForce(hfSpec, target, false); p != 10 {
		t.Fatalf("expected HF env parallelism, got %d", p)
	}
}
// TestFanoutPeerWaitTimeoutScalesForLargeDirectoryFiles expects
// fanoutPeerWaitTimeout to return more than the 1-second base timeout, but no
// more than 15 minutes, for a single 64 GiB byte range.
func TestFanoutPeerWaitTimeoutScalesForLargeDirectoryFiles(t *testing.T) {
	t.Parallel()

	const (
		baseTimeout = time.Second
		timeoutCap  = 15 * time.Minute
	)

	manifest := sharedtypes.LogicalManifest{
		Files: []sharedtypes.ArtifactFile{{RelativePath: "weights.bin"}},
	}
	collective := sharedtypes.CollectiveSpec{
		SessionID: "task-file-001",
		Ring:      &sharedtypes.RingPeerPlan{Rank: 0, WorldSize: 2},
	}
	spec := sharedtypes.TransferSpec{
		TransferMode:     sharedtypes.TransferModePartialPullAllGather,
		TimeoutSeconds:   1,
		TaskID:           "task-file-001",
		PreserveExisting: true,
		LogicalManifest:  manifest,
		CollectiveSpec:   collective,
	}
	// 64 << 30 bytes == 64 GiB.
	hugeRange := []sharedtypes.ByteRange{{Start: 0, End: 64 << 30}}

	timeout := fanoutPeerWaitTimeout(spec, hugeRange)
	switch {
	case timeout <= baseTimeout:
		t.Fatalf("expected derived timeout larger than base timeout, got %s", timeout)
	case timeout > timeoutCap:
		t.Fatalf("expected derived timeout to be capped, got %s", timeout)
	}
}
// testContextKey is a dedicated string type for context keys set by tests, so
// they cannot collide with keys that use the built-in string type.
type testContextKey string
// TestListCollectiveChunksHandlesSuccessAndFailures drives
// adapter.listCollectiveChunks against an httptest server. It covers a
// successful listing, a non-success HTTP status, and a malformed JSON body.
//
// The failure mode is chosen per request: the test stores a testContextKey in
// the request context, and the custom round tripper turns it into an
// X-Test-Case header that the handler switches on.
func TestListCollectiveChunksHandlesSuccessAndFailures(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/collectives/chunks/list" {
			// The handler runs on a server goroutine, where t.Fatalf (FailNow)
			// must not be called. Record the failure and answer with an error
			// status so the calling test goroutine fails as well.
			t.Errorf("unexpected path %s", r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		switch r.Header.Get("X-Test-Case") {
		case "bad-status":
			http.Error(w, "bad", http.StatusTeapot)
		case "bad-json":
			// Deliberately truncated JSON; a write error here is not under test.
			_, _ = w.Write([]byte("{"))
		default:
			response := sharedtypes.ListCollectiveChunksResponse{
				ExpectedChunks: 1,
				Chunks: []sharedtypes.CollectiveChunkPayload{{
					Chunk: sharedtypes.TransferredChunk{RelativePath: "weights.bin", Offset: 0, Size: 8},
				}},
			}
			if err := json.NewEncoder(w).Encode(response); err != nil {
				t.Errorf("encode chunk list response: %v", err)
			}
		}
	}))
	defer server.Close()

	client := roundTripFunc(func(req *http.Request) (*http.Response, error) {
		ctx := req.Context()
		testCase := ""
		if ctx.Value(testContextKey("bad-status")) == true {
			testCase = "bad-status"
		}
		if ctx.Value(testContextKey("bad-json")) == true {
			testCase = "bad-json"
		}
		if testCase != "" {
			// A RoundTripper must not modify the caller's request, so the
			// header is set on a clone.
			req = req.Clone(ctx)
			req.Header.Set("X-Test-Case", testCase)
		}
		return http.DefaultTransport.RoundTrip(req)
	})
	adapter := NewAdapter(NewHTTPChunkClient(&http.Client{Transport: client}), nil)

	got, err := adapter.listCollectiveChunks(context.Background(), server.URL, collectiveListRequest("task", "session", 1), "test")
	if err != nil {
		t.Fatalf("list chunks success returned error: %v", err)
	}
	if len(got.Chunks) != 1 || got.Chunks[0].Chunk.RelativePath != "weights.bin" {
		t.Fatalf("unexpected chunk list response: %#v", got)
	}

	statusCtx := context.WithValue(context.Background(), testContextKey("bad-status"), true)
	if _, err := adapter.listCollectiveChunks(statusCtx, server.URL, collectiveListRequest("task", "session", 1), "test"); err == nil {
		t.Fatalf("expected bad status error")
	}

	jsonCtx := context.WithValue(context.Background(), testContextKey("bad-json"), true)
	if _, err := adapter.listCollectiveChunks(jsonCtx, server.URL, collectiveListRequest("task", "session", 1), "test"); err == nil {
		t.Fatalf("expected bad json error")
	}
}
// TestHandleRelayCollectiveMetadataPoll calls handleRelayCollectiveMetadataPoll
// twice: once with a response that already holds the expected chunk, and once
// with an incomplete response under an already-cancelled context.
func TestHandleRelayCollectiveMetadataPoll(t *testing.T) {
	t.Parallel()

	adapter := NewAdapter(LocalChunkClient{}, nil)
	available := []sharedtypes.CollectiveChunkPayload{
		{Chunk: sharedtypes.TransferredChunk{RelativePath: "weights.bin"}},
	}

	// An hour-long period keeps the ticker from firing during the test.
	pollTicker := time.NewTicker(time.Hour)
	defer pollTicker.Stop()

	// Response already holds the one expected chunk.
	fullResponse := sharedtypes.ListCollectiveChunksResponse{ExpectedChunks: 1, Chunks: available}
	gotChunks, done, pollErr := adapter.handleRelayCollectiveMetadataPoll(
		context.Background(),
		"task",
		1,
		time.Now().Add(time.Second),
		fullResponse,
		pollTicker,
	)
	if pollErr != nil || !done || len(gotChunks) != 1 {
		t.Fatalf("expected complete metadata poll, chunks=%#v complete=%v err=%v", gotChunks, done, pollErr)
	}

	// Two chunks expected but only one present, with the context already cancelled.
	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()
	partialResponse := sharedtypes.ListCollectiveChunksResponse{ExpectedChunks: 2, Chunks: available}
	_, done, pollErr = adapter.handleRelayCollectiveMetadataPoll(
		cancelledCtx,
		"task",
		1,
		time.Now().Add(time.Second),
		partialResponse,
		pollTicker,
	)
	if done || pollErr == nil {
		t.Fatalf("expected canceled incomplete metadata poll, complete=%v err=%v", done, pollErr)
	}
}
func TestCollectiveRelayHelperDefaultAndErrorBranches(t *testing.T) {
t.Parallel()
if err := validateRelayPeerFetchRole(relayFanoutRole{isRoot: true}); err != nil {
t.Fatalf("root relay role should be valid: %v", err)
}
if err := validateRelayPeerFetchRole(relayFanoutRole{rootPeer: sharedtypes.CollectivePeerPlan{StagingPath: "/staging"}}); err == nil || !strings.Contains(err.Error(), "root endpoint") {
t.Fatalf("expected missing endpoint error, got %v", err)
}
if err := validateRelayPeerFetchRole(relayFanoutRole{rootPeer: sharedtypes.CollectivePeerPlan{Endpoint: "http://root"}}); err == nil || !strings.Contains(err.Error(), "root staging") {
t.Fatalf("expected missing staging path error, got %v", err)
}
adapter := NewAdapterWithOptions(LocalChunkClient{}, slog.New(slog.NewTextHandler(io.Discard, nil)), AdapterOptions{
EnableFFI: false,
ForceTCPFallback: true,
})
result, transport, err := adapter.runRelayPeerFetch(context.Background(), relayFetchContext{
spec: sharedtypes.TransferSpec{TaskID: "relay-default"},
role: relayFanoutRole{rootPeer: sharedtypes.CollectivePeerPlan{NodeName: "root"}},
}, relayFetchState{})
if err != nil {
t.Fatalf("runRelayPeerFetch default branch returned error: %v", err)
}
if result.BytesTransferred != 0 || transport != sharedtypes.TransportPathRDMA {
t.Fatalf("unexpected default relay result: result=%#v transport=%s", result, transport)
}
hints := countRelayHints([]sharedtypes.CollectiveChunkPayload{
{RelayRDMA: &sharedtypes.RelayRDMAHint{SessionID: "session-a"}},
{RelayRDMA: &sharedtypes.RelayRDMAHint{}},
{},
})
if hints != 1 {
t.Fatalf("expected one relay hint with session id, got %d", hints)
}
segmentKinds := collectiveSegmentKinds([]sharedtypes.SourceSegmentPlan{{
SourceEndpoint: sharedtypes.SourceEndpoint{RelayRDMA: &sharedtypes.RelayRDMAHint{SessionID: "session-a"}},
}})
if len(segmentKinds) != 1 || segmentKinds[0] != "relay:session-a:ranges=0:first=-1--1" {
t.Fatalf("unexpected relay segment kinds: %#v", segmentKinds)
}
if firstRangeStart(nil) != -1 || firstRangeEnd(nil) != -1 {
t.Fatalf("empty range helpers should return -1")
}
fetch := relayFetchContext{
spec: sharedtypes.TransferSpec{TaskID: "relay-error"},
role: relayFanoutRole{rootPeer: sharedtypes.CollectivePeerPlan{NodeName: "root"}},
resultCh: make(chan relayFetchRun, 1),
}
fetch.resultCh <- relayFetchRun{iteration: 7, err: errors.New("relay failed")}
_, err = adapter.collectRelayPeerFetchRun(context.Background(), fetch, relayFetchState{inflight: 1})
if err == nil || !strings.Contains(err.Error(), "relay failed") {
t.Fatalf("expected relay run error, got %v", err)
}
}