* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package rdma
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/openfuyao/weight-dispatcher/pkg/dataplane/rdmaffi"
"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)
// Defaults for per-file directory transfer pipelining. Each can be overridden
// through an environment variable read by the function that uses it.
const (
	// defaultDirectoryFanoutWindow is the sequential fanout window: file i waits
	// for file i-window to be marked done (see directoryFanoutWindow).
	defaultDirectoryFanoutWindow = 2
	// defaultTCPFanoutSourceAhead bounds how many source-phase items the pipelined
	// TCP fanout producer may buffer (WD_TCP_FANOUT_SOURCE_AHEAD).
	defaultTCPFanoutSourceAhead = 10
	// defaultTCPFanoutExchangeWorkers is the number of concurrent exchange workers
	// in the pipelined TCP fanout (WD_TCP_FANOUT_EXCHANGE_WORKERS).
	defaultTCPFanoutExchangeWorkers = 8
)

// Default per-file parallelism floors for chunkable files in multi-file
// directories, used when no environment override is configured.
const (
	minTCPStripedFileParallelism  int32 = 4
	minTCPFanoutFileParallelism   int32 = 24
	minHuggingFaceFileParallelism int32 = 8
)
// directoryFanoutPipelineItem is what the pipelined fanout producer hands to an
// exchange worker: one file plus the outcome of the producer-side work for it.
type directoryFanoutPipelineItem struct {
// idx is the file's 0-based position in the directory's file list.
idx int
// file is the artifact this item transfers.
file sharedtypes.ArtifactFile
// subSpec is the per-file spec produced by buildPerFileExecutionSpec.
subSpec sharedtypes.TransferSpec
// np is the resolved N-peer fanout spec; it is left zero for atomic items and
// for items whose target preparation failed.
np nPeerFanoutSpec
// srcResult is the source-phase result, or the full a.Execute result when isAtomic.
srcResult sharedtypes.TransferResult
// srcTp is the transport path accompanying srcResult.
srcTp sharedtypes.TransportPath
// isAtomic marks files that did not resolve to a fanout spec and were
// transferred entirely by the producer; the worker skips the exchange phase.
isAtomic bool
// err is the producer-side failure, if any; the worker turns it into an error result.
err error
}
// directoryFanoutPipelineResult is what an exchange worker reports for one file
// to the result collector.
type directoryFanoutPipelineResult struct {
// idx is the file's 0-based position in the directory's file list.
idx int
// file is the artifact this result describes.
file sharedtypes.ArtifactFile
// subSpec is the per-file spec the file was transferred with.
subSpec sharedtypes.TransferSpec
// result is the merged source+exchange result (or the producer's result for
// atomic items); it is left zero when err is set.
result sharedtypes.TransferResult
// tp is the merged transport path for result.
tp sharedtypes.TransportPath
// err is this file's failure, if any.
err error
}
// shouldExecuteDirectoryPerFile reports whether a directory transfer is split
// into one sub-transfer per file. That happens only for directories with at
// least two files whose mode is single-source direct, direct striped, or
// partial-pull all-gather.
func shouldExecuteDirectoryPerFile(spec sharedtypes.TransferSpec, files []sharedtypes.ArtifactFile) bool {
	if len(files) < 2 {
		return false
	}
	mode := spec.TransferMode
	return mode == sharedtypes.TransferModeSingleSourceDirect ||
		mode == sharedtypes.TransferModeDirectStriped ||
		mode == sharedtypes.TransferModePartialPullAllGather
}
func (a *Adapter) executeDirectoryPerFile(
ctx context.Context,
start time.Time,
spec sharedtypes.TransferSpec,
files []sharedtypes.ArtifactFile,
resolveMs int64,
) (sharedtypes.TransferResult, error) {
if err := resetDirectoryTransferTarget(spec.TargetTempPath); err != nil {
a.logger.Error("prepare per-file directory target failed", "taskID", spec.TaskID, "targetTempPath", spec.TargetTempPath, "err", err)
return sharedtypes.TransferResult{}, err
}
if shouldPipelineDirectoryFanout(spec, files, a.forceTCPFallback) {
return a.executeDirectoryFanoutPipelined(ctx, start, spec, files, resolveMs)
}
aggregate := sharedtypes.TransferResult{}
transportPath := sharedtypes.TransportPath("")
fanoutWindow := directoryFanoutWindow(spec, files)
for idx := range files {
subResult, nextPath, err := a.executeDirectoryPerFileStep(ctx, spec, files, idx, fanoutWindow)
if err != nil {
return sharedtypes.TransferResult{}, err
}
aggregate = mergeTransferResults(aggregate, subResult)
transportPath = mergeTransportPath(transportPath, nextPath)
}
result := finalizeTransferResult(start, spec, files, aggregate, transportPath)
a.logDirectoryPerFileCompleted(start, spec, files, result, resolveMs)
return result, nil
}
// shouldPipelineDirectoryFanout reports whether a directory transfer should use
// the pipelined fanout executor. That requires forced TCP fallback, more than
// one file, partial-pull all-gather mode, and a ring in the collective spec.
func shouldPipelineDirectoryFanout(spec sharedtypes.TransferSpec, files []sharedtypes.ArtifactFile, forceTCPFallback bool) bool {
	switch {
	case !forceTCPFallback, len(files) < 2:
		return false
	case spec.TransferMode != sharedtypes.TransferModePartialPullAllGather:
		return false
	default:
		return spec.CollectiveSpec.Ring != nil
	}
}
// tcpDirectoryFanoutSourceAhead returns the item-channel buffer size for the
// pipelined fanout producer. The value is WD_TCP_FANOUT_SOURCE_AHEAD, or
// defaultTCPFanoutSourceAhead when unset or invalid, bounded by
// boundedPipelineConfig.
func tcpDirectoryFanoutSourceAhead(files []sharedtypes.ArtifactFile) int {
	const envKey = "WD_TCP_FANOUT_SOURCE_AHEAD"
	fileCount := len(files)
	return boundedPipelineConfig(envKey, defaultTCPFanoutSourceAhead, fileCount)
}
// tcpDirectoryFanoutExchangeWorkers returns how many exchange workers the
// pipelined fanout starts. The value is WD_TCP_FANOUT_EXCHANGE_WORKERS, or
// defaultTCPFanoutExchangeWorkers when unset or invalid, bounded by
// boundedPipelineConfig.
func tcpDirectoryFanoutExchangeWorkers(files []sharedtypes.ArtifactFile) int {
	const envKey = "WD_TCP_FANOUT_EXCHANGE_WORKERS"
	fileCount := len(files)
	return boundedPipelineConfig(envKey, defaultTCPFanoutExchangeWorkers, fileCount)
}
// boundedPipelineConfig resolves an integer pipeline knob.
//
// The value is taken from the environment variable envKey when it parses as a
// base-10 32-bit integer (surrounding whitespace ignored). Otherwise
// defaultValue is used. The result is clamped to [1, fileCount], and it is
// always 1 when fileCount is 1 or less.
func boundedPipelineConfig(envKey string, defaultValue, fileCount int) int {
	if fileCount < 2 {
		return 1
	}
	resolved := defaultValue
	raw := strings.TrimSpace(os.Getenv(envKey))
	if parsed, err := strconv.ParseInt(raw, 10, 32); raw != "" && err == nil {
		resolved = int(parsed)
	}
	if resolved < 1 {
		return 1
	}
	if resolved > fileCount {
		return fileCount
	}
	return resolved
}
// executeDirectoryFanoutPipelined runs a multi-file TCP partial-pull fanout as a
// two-stage pipeline. A single producer runs each file's source phase in file
// order, buffering up to sourceAhead items. exchangeWorkers goroutines run the
// exchange phases concurrently. Successful per-file results are merged and
// finalized.
//
// The first per-file failure is returned as the error. If the caller's context
// ends while the pipeline runs, item and result sends abort on ctx.Done without
// producing an error result, so files can be silently skipped. That case is
// detected here and reported as an error instead of returning a partial result
// as success.
func (a *Adapter) executeDirectoryFanoutPipelined(
	ctx context.Context,
	start time.Time,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	resolveMs int64,
) (sharedtypes.TransferResult, error) {
	// A private cancel lets the collector stop the producer and the workers as
	// soon as any file fails.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	sourceAhead := tcpDirectoryFanoutSourceAhead(files)
	exchangeWorkers := tcpDirectoryFanoutExchangeWorkers(files)
	resultCh := a.startDirectoryFanoutPipeline(ctx, spec, files, sourceAhead, exchangeWorkers)
	aggregate, transportPath, err := collectDirectoryFanoutPipelineResults(cancel, resultCh)
	if err == nil {
		// The collector only calls cancel after a per-file error, so a done
		// context here means the caller's context ended and results may be
		// incomplete.
		if ctxErr := ctx.Err(); ctxErr != nil {
			err = fmt.Errorf("directory fanout pipeline interrupted: %w", ctxErr)
		}
	}
	if err != nil {
		a.logger.Error("directory fanout pipeline failed", "taskID", spec.TaskID, "fileCount", len(files), "err", err)
		return sharedtypes.TransferResult{}, err
	}
	result := finalizeTransferResult(start, spec, files, aggregate, transportPath)
	a.logDirectoryPerFileCompleted(start, spec, files, result, resolveMs, "sourceAhead", sourceAhead, "exchangeWorkers", exchangeWorkers)
	return result, nil
}
func (a *Adapter) startDirectoryFanoutPipeline(
ctx context.Context,
spec sharedtypes.TransferSpec,
files []sharedtypes.ArtifactFile,
sourceAhead int,
exchangeWorkers int,
) <-chan directoryFanoutPipelineResult {
itemCh := make(chan directoryFanoutPipelineItem, sourceAhead)
resultCh := make(chan directoryFanoutPipelineResult, exchangeWorkers)
go a.produceDirectoryFanoutPipelineItems(ctx, spec, files, itemCh)
a.startDirectoryFanoutExchangeWorkers(ctx, spec, itemCh, resultCh, exchangeWorkers)
return resultCh
}
// produceDirectoryFanoutPipelineItems is the producer stage of the pipeline. It
// builds each file's item (running that file's producer-side transfer work) in
// file order and sends it on ch. It closes ch on return. It stops after the
// last file, after sending an item that carries an error, or once ctx is done.
func (a *Adapter) produceDirectoryFanoutPipelineItems(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	ch chan<- directoryFanoutPipelineItem,
) {
	defer close(ch)
	for idx, file := range files {
		// Check before building: building prepares target files and runs
		// transfer phases. A send can still succeed after cancellation, because
		// select picks randomly among ready cases, so relying on the send alone
		// would let the producer keep working after the pipeline is cancelled.
		if ctx.Err() != nil {
			return
		}
		item := a.buildDirectoryFanoutPipelineItem(ctx, spec, file, idx)
		if !sendDirectoryFanoutPipelineItem(ctx, ch, item) || item.err != nil {
			return
		}
	}
}
// buildDirectoryFanoutPipelineItem performs the producer-side work for one file.
//
// If the per-file spec resolves to an N-peer fanout spec, the file's target is
// prepared (and the opened files closed again) and its source phase is run.
// Otherwise the file is transferred in full through a.Execute and the item is
// marked isAtomic, so the exchange worker skips the exchange phase. Any failure
// is carried in the item's err field rather than returned.
func (a *Adapter) buildDirectoryFanoutPipelineItem(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	file sharedtypes.ArtifactFile,
	idx int,
) directoryFanoutPipelineItem {
	item := directoryFanoutPipelineItem{
		idx:     idx,
		file:    file,
		subSpec: buildPerFileExecutionSpec(spec, file, idx, a.forceTCPFallback),
	}
	np, isFanout := resolveNPeerFanoutSpec(item.subSpec)
	if !isFanout {
		item.srcResult, item.err = a.Execute(ctx, item.subSpec)
		item.srcTp = item.srcResult.TransportPath
		item.isAtomic = true
		return item
	}
	prepared, prepErr := prepareTargetFilesPreserve(item.subSpec.TargetTempPath, []sharedtypes.ArtifactFile{file})
	if prepErr != nil {
		item.err = prepErr
		return item
	}
	closeOpenFiles(prepared)
	item.np = np
	item.srcResult, item.srcTp, item.err = a.runFanoutSourcePhase(ctx, item.subSpec, np)
	return item
}
// startDirectoryFanoutExchangeWorkers launches exchangeWorkers goroutines. Each
// one consumes items from itemCh, runs the item's exchange step, and forwards
// the outcome to resultCh.
//
// A worker exits when itemCh is closed, when a send is abandoned because ctx is
// done, or right after forwarding an error result. resultCh is closed once all
// workers have exited.
func (a *Adapter) startDirectoryFanoutExchangeWorkers(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	itemCh <-chan directoryFanoutPipelineItem,
	resultCh chan<- directoryFanoutPipelineResult,
	exchangeWorkers int,
) {
	consume := func() {
		for item := range itemCh {
			res := a.executeDirectoryFanoutPipelineItem(ctx, spec, item)
			if res.err != nil {
				sendDirectoryFanoutPipelineResult(ctx, resultCh, res)
				return
			}
			if !sendDirectoryFanoutPipelineResult(ctx, resultCh, res) {
				return
			}
		}
	}

	var running sync.WaitGroup
	for n := 0; n < exchangeWorkers; n++ {
		running.Add(1)
		go func() {
			defer running.Done()
			consume()
		}()
	}
	go func() {
		running.Wait()
		close(resultCh)
	}()
}
// collectDirectoryFanoutPipelineResults reads resultCh until it is closed. It
// merges every successful per-file result and transport path. On an error it
// calls cancel to stop the rest of the pipeline but keeps reading until the
// channel closes. Only the first error is returned.
func collectDirectoryFanoutPipelineResults(
	cancel context.CancelFunc,
	resultCh <-chan directoryFanoutPipelineResult,
) (sharedtypes.TransferResult, sharedtypes.TransportPath, error) {
	var (
		merged    sharedtypes.TransferResult
		transport sharedtypes.TransportPath
		firstErr  error
	)
	for res := range resultCh {
		if res.err == nil {
			merged = mergeTransferResults(merged, res.result)
			transport = mergeTransportPath(transport, res.tp)
			continue
		}
		cancel()
		if firstErr == nil {
			firstErr = res.err
		}
	}
	return merged, transport, firstErr
}
func (a *Adapter) executeDirectoryFanoutPipelineItem(
ctx context.Context,
spec sharedtypes.TransferSpec,
item directoryFanoutPipelineItem,
) directoryFanoutPipelineResult {
if item.err != nil {
return directoryFanoutPipelineResult{
idx: item.idx,
file: item.file,
subSpec: item.subSpec,
err: fmt.Errorf("per-file fanout source for %s (file %d): %w", item.file.RelativePath, item.idx+1, item.err),
}
}
subStarted := time.Now()
var merged sharedtypes.TransferResult
var mergedTp sharedtypes.TransportPath
if item.isAtomic {
merged = item.srcResult
mergedTp = item.srcTp
} else {
exResult, exTp, exErr := a.runFanoutExchangePhase(ctx, item.subSpec, item.np)
if exErr != nil {
return directoryFanoutPipelineResult{
idx: item.idx,
file: item.file,
subSpec: item.subSpec,
err: fmt.Errorf("per-file fanout exchange for %s (file %d): %w", item.file.RelativePath, item.idx+1, exErr),
}
}
merged = mergeTransferResults(item.srcResult, exResult)
mergedTp = mergeTransportPath(item.srcTp, exTp)
}
if err := markFanoutFileDone(item.subSpec); err != nil {
return directoryFanoutPipelineResult{idx: item.idx, file: item.file, subSpec: item.subSpec, err: err}
}
logDirectoryPerFileStep(a, spec.TaskID, item.file, item.subSpec, merged, subStarted)
return directoryFanoutPipelineResult{
idx: item.idx,
file: item.file,
subSpec: item.subSpec,
result: merged,
tp: mergedTp,
}
}
// sendDirectoryFanoutPipelineItem delivers item on ch unless ctx is done first,
// and reports whether the item was sent. If both are ready at once, select
// picks one at random.
func sendDirectoryFanoutPipelineItem(
	ctx context.Context,
	ch chan<- directoryFanoutPipelineItem,
	item directoryFanoutPipelineItem,
) bool {
	sent := false
	select {
	case <-ctx.Done():
	case ch <- item:
		sent = true
	}
	return sent
}
// drainDirectoryFanoutPipeline discards items from ch until ch is closed. It
// blocks until then.
func drainDirectoryFanoutPipeline(ch <-chan directoryFanoutPipelineItem) {
	for {
		if _, open := <-ch; !open {
			return
		}
	}
}
// sendDirectoryFanoutPipelineResult delivers result on ch unless ctx is done
// first, and reports whether the result was sent. If both are ready at once,
// select picks one at random.
func sendDirectoryFanoutPipelineResult(
	ctx context.Context,
	ch chan<- directoryFanoutPipelineResult,
	result directoryFanoutPipelineResult,
) bool {
	sent := false
	select {
	case <-ctx.Done():
	case ch <- result:
		sent = true
	}
	return sent
}
// logDirectoryPerFileCompleted emits one Info record summarizing a finished
// directory-per-file transfer. Caller-supplied key/value pairs in extra are
// placed between the result attributes and the timing attributes.
func (a *Adapter) logDirectoryPerFileCompleted(
	start time.Time,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	result sharedtypes.TransferResult,
	resolveMs int64,
	extra ...any,
) {
	// 7 leading pairs + extra + 2 trailing pairs.
	attrs := make([]any, 0, 14+len(extra)+4)
	attrs = append(attrs,
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"scenario", classifyTransferScenario(spec),
		"fileCount", len(files),
		"bytes", result.BytesTransferred,
		"transportPath", result.TransportPath,
		"throughputMBps", result.ThroughputMBps,
	)
	attrs = append(attrs, extra...)
	attrs = append(attrs,
		"resolveFilesMs", resolveMs,
		"totalExecuteMs", time.Since(start).Milliseconds(),
	)
	a.logger.Info("data-plane directory-per-file transfer completed", attrs...)
}
// resetDirectoryTransferTarget removes targetTempPath and everything under it,
// then recreates it (and any missing parents) as an empty directory with
// rdmaDirPerm permissions.
func resetDirectoryTransferTarget(targetTempPath string) error {
	removeErr := os.RemoveAll(targetTempPath)
	if removeErr != nil {
		return fmt.Errorf("reset temp path for per-file directory transfer: %w", removeErr)
	}
	mkdirErr := os.MkdirAll(targetTempPath, rdmaDirPerm)
	if mkdirErr != nil {
		return fmt.Errorf("create temp path for per-file directory transfer: %w", mkdirErr)
	}
	return nil
}
// executeDirectoryPerFileStep runs one file of the sequential per-file path.
// It waits for the file's fanout-window slot, executes the per-file spec, marks
// the file's fanout as done, and logs the step. It returns the file's result
// and transport path. Each failure is logged and returned as an error.
func (a *Adapter) executeDirectoryPerFileStep(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	idx int,
	fanoutWindow int,
) (sharedtypes.TransferResult, sharedtypes.TransportPath, error) {
	fail := func(err error) (sharedtypes.TransferResult, sharedtypes.TransportPath, error) {
		return sharedtypes.TransferResult{}, "", err
	}
	if waitErr := a.waitForDirectoryFanoutSlot(ctx, spec, files, idx, fanoutWindow); waitErr != nil {
		a.logger.Error("wait directory fanout slot failed", "taskID", spec.TaskID, "relativePath", files[idx].RelativePath, "index", idx, "err", waitErr)
		return fail(waitErr)
	}

	file := files[idx]
	subSpec := buildPerFileExecutionSpec(spec, file, idx, a.forceTCPFallback)
	subStarted := time.Now()
	subResult, execErr := a.Execute(ctx, subSpec)
	if execErr != nil {
		a.logger.Error("execute per-file directory transfer failed", "taskID", spec.TaskID, "subTaskID", subSpec.TaskID, "relativePath", file.RelativePath, "err", execErr)
		return fail(fmt.Errorf("per-file transfer failed for %s: %w", file.RelativePath, execErr))
	}
	if markErr := markFanoutFileDone(subSpec); markErr != nil {
		a.logger.Error("mark directory fanout done failed", "taskID", spec.TaskID, "subTaskID", subSpec.TaskID, "relativePath", file.RelativePath, "err", markErr)
		return fail(fmt.Errorf("mark fanout done for %s: %w", file.RelativePath, markErr))
	}
	logDirectoryPerFileStep(a, spec.TaskID, file, subSpec, subResult, subStarted)
	return subResult, subResult.TransportPath, nil
}
// waitForDirectoryFanoutSlot enforces the fanout window on the sequential path.
// File idx may start only once file idx-fanoutWindow has been reported done by
// waitForPreviousDirectoryFanoutFile. The first fanoutWindow files start
// immediately.
func (a *Adapter) waitForDirectoryFanoutSlot(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	idx int,
	fanoutWindow int,
) error {
	prevIdx := idx - fanoutWindow
	if prevIdx < 0 {
		return nil
	}
	prev := files[prevIdx]
	prevSubSpec := buildPerFileExecutionSpec(spec, prev, prevIdx, a.forceTCPFallback)
	err := waitForPreviousDirectoryFanoutFile(ctx, a.client, prevSubSpec, prev)
	if err == nil {
		return nil
	}
	return fmt.Errorf("wait previous fanout file %s done: %w", prev.RelativePath, err)
}
// buildPerFileExecutionSpec derives the executable sub-spec for the file at
// position idx of a directory transfer.
//
// The sub-task ID gets a "-file-NNN" suffix using the 1-based index.
// PreserveExisting is always set; presumably this keeps files already written
// to the shared target in place (confirm). For partial-pull all-gather sub-specs
// that still have peers, the collective session ID gets the same suffix.
func buildPerFileExecutionSpec(
	spec sharedtypes.TransferSpec,
	file sharedtypes.ArtifactFile,
	idx int,
	forceTCPFallback bool,
) sharedtypes.TransferSpec {
	ordinal := idx + 1
	sub := buildPerFileTransferSpecWithForce(spec, file, forceTCPFallback)
	sub.TaskID = fmt.Sprintf("%s-file-%03d", spec.TaskID, ordinal)
	sub.PreserveExisting = true
	isCollective := sub.TransferMode == sharedtypes.TransferModePartialPullAllGather &&
		len(sub.CollectiveSpec.Peers) > 0
	if isCollective {
		sub.CollectiveSpec.SessionID = fmt.Sprintf("%s-file-%03d", collectiveSessionID(spec), ordinal)
	}
	return sub
}
// logDirectoryPerFileStep emits an Info record for one finished per-file
// sub-transfer: its IDs, mode, path, declared size, bytes moved, transport path,
// and the milliseconds elapsed since subStarted.
func logDirectoryPerFileStep(
	a *Adapter,
	taskID string,
	file sharedtypes.ArtifactFile,
	subSpec sharedtypes.TransferSpec,
	subResult sharedtypes.TransferResult,
	subStarted time.Time,
) {
	attrs := []any{
		"taskID", taskID,
		"subTaskID", subSpec.TaskID,
		"mode", subSpec.TransferMode,
		"relativePath", file.RelativePath,
		"sizeBytes", file.SizeBytes,
		"bytesTransferred", subResult.BytesTransferred,
		"transportPath", subResult.TransportPath,
		"elapsedMs", time.Since(subStarted).Milliseconds(),
	}
	a.logger.Info("per-file directory transfer completed", attrs...)
}
// directoryFanoutWindow returns the sequential path's fanout window: file i
// must wait for file i-window to be done before it starts.
//
// Only ring-based partial-pull all-gather transfers of more than one file get a
// window larger than 1. WD_DIRECTORY_FANOUT_WINDOW overrides
// defaultDirectoryFanoutWindow. The result is resolved and clamped to
// [1, len(files)] by boundedPipelineConfig, the same way as the other pipeline
// knobs. (The previous hand-rolled fallback returned 1 rather than clamping
// whenever the default exceeded the file count.)
func directoryFanoutWindow(spec sharedtypes.TransferSpec, files []sharedtypes.ArtifactFile) int {
	if len(files) <= 1 {
		return 1
	}
	if spec.TransferMode != sharedtypes.TransferModePartialPullAllGather || spec.CollectiveSpec.Ring == nil {
		return 1
	}
	return boundedPipelineConfig("WD_DIRECTORY_FANOUT_WINDOW", defaultDirectoryFanoutWindow, len(files))
}
// buildPerFileTransferSpec calls buildPerFileTransferSpecWithForce, taking the
// forced-TCP flag from forceTCPFallbackDataPlane().
func buildPerFileTransferSpec(spec sharedtypes.TransferSpec, file sharedtypes.ArtifactFile) sharedtypes.TransferSpec {
	force := forceTCPFallbackDataPlane()
	return buildPerFileTransferSpecWithForce(spec, file, force)
}
// buildPerFileTransferSpecWithForce narrows a directory spec to a single file.
//
// For partial-pull all-gather transfers of a non-chunkable file it delegates to
// buildWholeFileFanoutSubSpec. Otherwise the manifest, source segments, and
// collective peers are filtered to the file and the per-file parallelism is
// chosen. A partial-pull sub-spec left with no byte ranges is downgraded to a
// single-source direct pull of the whole file.
func buildPerFileTransferSpecWithForce(spec sharedtypes.TransferSpec, file sharedtypes.ArtifactFile, forceTCPFallback bool) sharedtypes.TransferSpec {
	isPartialPull := spec.TransferMode == sharedtypes.TransferModePartialPullAllGather
	if isPartialPull && !file.Chunkable {
		return buildWholeFileFanoutSubSpec(spec, file)
	}

	sub := spec
	sub.LogicalManifest.Files = []sharedtypes.ArtifactFile{file}
	sub.SourceSegments = filterSourceSegmentsForFile(spec.TransferMode, spec.SourceSegments, file.RelativePath)
	sub.CollectiveSpec = filterCollectiveSpecForFile(spec.CollectiveSpec, file.RelativePath)
	sub.Parallelism = directoryPerFileParallelismWithForce(spec, file, forceTCPFallback)

	if isPartialPull && countSegmentRanges(sub.SourceSegments) == 0 {
		sub.TransferMode = sharedtypes.TransferModeSingleSourceDirect
		sub.CollectiveSpec = sharedtypes.CollectiveSpec{}
		sub.SourceSegments = buildFullFileSourceSegments(spec.SourceSegments, file.RelativePath)
	}
	return sub
}
// buildWholeFileFanoutSubSpec builds the sub-spec for a non-chunkable file in a
// partial-pull all-gather directory. Parallelism is forced to 1.
//
// Without a ring or peers the file falls back to a single-source direct pull.
// Otherwise rank-0 peers own the whole file [0, SizeBytes) and every other peer
// owns nothing. Only a local rank of 0 keeps a (whole-file) source segment.
func buildWholeFileFanoutSubSpec(parent sharedtypes.TransferSpec, file sharedtypes.ArtifactFile) sharedtypes.TransferSpec {
	sub := parent
	sub.LogicalManifest.Files = []sharedtypes.ArtifactFile{file}
	sub.Parallelism = 1
	sub.CollectiveSpec = filterCollectiveSpecForFile(parent.CollectiveSpec, file.RelativePath)

	ring := sub.CollectiveSpec.Ring
	if ring == nil || len(sub.CollectiveSpec.Peers) == 0 {
		sub.TransferMode = sharedtypes.TransferModeSingleSourceDirect
		sub.CollectiveSpec = sharedtypes.CollectiveSpec{}
		sub.SourceSegments = buildFullFileSourceSegments(parent.SourceSegments, file.RelativePath)
		return sub
	}

	whole := sharedtypes.ByteRange{RelativePath: file.RelativePath, Start: 0, End: file.SizeBytes}
	// Peers is a fresh slice from filterCollectiveSpecForFile, so editing it in
	// place does not touch the parent spec.
	for i := range sub.CollectiveSpec.Peers {
		peer := &sub.CollectiveSpec.Peers[i]
		peer.OwnedRanges = nil
		if peer.Rank == 0 {
			peer.OwnedRanges = []sharedtypes.ByteRange{whole}
		}
	}

	sub.SourceSegments = nil
	if ring.Rank == 0 {
		sub.SourceSegments = buildFullFileSourceSegments(parent.SourceSegments, file.RelativePath)
	}
	return sub
}
// directoryPerFileParallelism calls directoryPerFileParallelismWithForce,
// taking the forced-TCP flag from forceTCPFallbackDataPlane().
func directoryPerFileParallelism(spec sharedtypes.TransferSpec, file sharedtypes.ArtifactFile) int32 {
	force := forceTCPFallbackDataPlane()
	return directoryPerFileParallelismWithForce(spec, file, force)
}
// directoryPerFileParallelismWithForce picks the parallelism for one file's
// sub-transfer, starting from spec.Parallelism floored at 1.
//
// Single-file directories and non-chunkable files keep that base value. For
// chunkable files in multi-file directories:
//   - forced-TCP striped and forced-TCP partial-pull transfers use their TCP floors;
//   - other modes except single-source direct keep the base;
//   - single-source direct from a HuggingFace source uses the HuggingFace floor;
//   - otherwise a parseable WD_RDMA_DIRECTORY_PER_FILE_SINGLE_PARALLELISM caps
//     the base (minimum 1); without it, forced TCP keeps the base and the
//     non-forced path uses 1.
func directoryPerFileParallelismWithForce(spec sharedtypes.TransferSpec, file sharedtypes.ArtifactFile, forceTCPFallback bool) int32 {
	base := spec.Parallelism
	if base < 1 {
		base = 1
	}
	if !file.Chunkable || len(spec.LogicalManifest.Files) < 2 {
		return base
	}

	mode := spec.TransferMode
	switch {
	case forceTCPFallback && mode == sharedtypes.TransferModeDirectStriped:
		return tcpStripedDirectoryParallelism(base)
	case forceTCPFallback && mode == sharedtypes.TransferModePartialPullAllGather:
		return tcpFanoutDirectoryParallelism(base)
	case mode != sharedtypes.TransferModeSingleSourceDirect:
		return base
	case hasHuggingFaceDirectorySource(spec.SourceSegments):
		return huggingFaceDirectoryParallelism(base)
	}

	if raw := strings.TrimSpace(os.Getenv("WD_RDMA_DIRECTORY_PER_FILE_SINGLE_PARALLELISM")); raw != "" {
		if parsed, err := strconv.ParseInt(raw, 10, 32); err == nil {
			capValue := int32(parsed)
			if capValue < 1 {
				return 1
			}
			if capValue < base {
				return capValue
			}
			return base
		}
	}
	if forceTCPFallback {
		return base
	}
	// base is at least 1 here, so reducing anything above 1 to 1 always yields 1.
	return 1
}
// tcpStripedDirectoryParallelism returns the per-file parallelism for
// forced-TCP striped directory transfers. The input is raised to
// WD_TCP_DIRECTORY_PER_FILE_STRIPED_PARALLELISM when that is a positive
// integer, and otherwise to at least minTCPStripedFileParallelism.
func tcpStripedDirectoryParallelism(parallelism int32) int32 {
	value, ok := maxConfiguredParallelism("WD_TCP_DIRECTORY_PER_FILE_STRIPED_PARALLELISM", parallelism)
	if !ok {
		value = maxParallelism(parallelism, minTCPStripedFileParallelism)
	}
	return value
}
// tcpFanoutDirectoryParallelism returns the per-file parallelism for forced-TCP
// partial-pull fanout directory transfers. The input is raised to
// WD_TCP_DIRECTORY_PER_FILE_FANOUT_PARALLELISM when that is a positive integer,
// and otherwise to at least minTCPFanoutFileParallelism.
func tcpFanoutDirectoryParallelism(parallelism int32) int32 {
	value, ok := maxConfiguredParallelism("WD_TCP_DIRECTORY_PER_FILE_FANOUT_PARALLELISM", parallelism)
	if !ok {
		value = maxParallelism(parallelism, minTCPFanoutFileParallelism)
	}
	return value
}
// hasHuggingFaceDirectorySource reports whether any segment's source endpoint
// is classified as HuggingFace by IsHuggingFaceEndpoint.
func hasHuggingFaceDirectorySource(segments []sharedtypes.SourceSegmentPlan) bool {
	found := false
	for i := 0; i < len(segments) && !found; i++ {
		endpoint := segments[i].SourceEndpoint
		found = IsHuggingFaceEndpoint(endpoint.SourceType, endpoint.Endpoint)
	}
	return found
}
// huggingFaceDirectoryParallelism returns the per-file parallelism for
// single-source HuggingFace directory transfers. The input is raised to
// WD_HF_DIRECTORY_PER_FILE_PARALLELISM when that is a positive integer, and
// otherwise to at least minHuggingFaceFileParallelism.
func huggingFaceDirectoryParallelism(parallelism int32) int32 {
	value, ok := maxConfiguredParallelism("WD_HF_DIRECTORY_PER_FILE_PARALLELISM", parallelism)
	if !ok {
		value = maxParallelism(parallelism, minHuggingFaceFileParallelism)
	}
	return value
}
// maxConfiguredParallelism reads a parallelism floor from the environment
// variable envKey. If the variable holds a positive base-10 32-bit integer
// (surrounding whitespace ignored), it returns the larger of that value and
// parallelism with ok=true. Otherwise it returns (0, false).
func maxConfiguredParallelism(envKey string, parallelism int32) (int32, bool) {
	raw := strings.TrimSpace(os.Getenv(envKey))
	if raw == "" {
		return 0, false
	}
	configured, err := strconv.ParseInt(raw, 10, 32)
	switch {
	case err != nil, configured < 1:
		return 0, false
	default:
		return maxParallelism(parallelism, int32(configured)), true
	}
}
// maxParallelism returns the larger of left and right.
func maxParallelism(left, right int32) int32 {
	if right > left {
		return right
	}
	return left
}
// buildFullFileSourceSegments returns a one-element plan copied from the first
// source segment with its ByteRanges cleared, or nil when there are no segments.
// NOTE(review): the relative-path argument is ignored and segments[0] is used
// regardless of file; presumably any segment can serve any file — confirm.
func buildFullFileSourceSegments(segments []sharedtypes.SourceSegmentPlan, _ string) []sharedtypes.SourceSegmentPlan {
	if len(segments) > 0 {
		first := segments[0]
		first.ByteRanges = nil
		return []sharedtypes.SourceSegmentPlan{first}
	}
	return nil
}
// filterSourceSegmentsForFile narrows source segments to relativePath.
//
// Segments with byte ranges keep only the ranges for that file and are dropped
// if none remain. Segments without byte ranges are kept only in single-source
// direct mode. The result is a new slice (never nil).
func filterSourceSegmentsForFile(
	mode sharedtypes.TransferMode,
	segments []sharedtypes.SourceSegmentPlan,
	relativePath string,
) []sharedtypes.SourceSegmentPlan {
	keepRangeless := mode == sharedtypes.TransferModeSingleSourceDirect
	out := make([]sharedtypes.SourceSegmentPlan, 0, len(segments))
	for i := range segments {
		seg := segments[i]
		switch {
		case len(seg.ByteRanges) == 0:
			if keepRangeless {
				out = append(out, seg)
			}
		default:
			seg.ByteRanges = filterByteRangesForFile(seg.ByteRanges, relativePath)
			if len(seg.ByteRanges) > 0 {
				out = append(out, seg)
			}
		}
	}
	return out
}
// filterCollectiveSpecForFile returns a copy of spec in which each peer's
// OwnedRanges are restricted to relativePath. When there are peers, the Peers
// slice is freshly allocated, so the caller may modify it without affecting spec.
func filterCollectiveSpecForFile(spec sharedtypes.CollectiveSpec, relativePath string) sharedtypes.CollectiveSpec {
	out := spec
	if len(spec.Peers) > 0 {
		peers := make([]sharedtypes.CollectivePeerPlan, len(spec.Peers))
		for i, peer := range spec.Peers {
			peer.OwnedRanges = filterByteRangesForFile(peer.OwnedRanges, relativePath)
			peers[i] = peer
		}
		out.Peers = peers
	}
	return out
}
// filterByteRangesForFile returns, in order, the ranges whose RelativePath
// equals relativePath. The result is a new slice (never nil).
func filterByteRangesForFile(ranges []sharedtypes.ByteRange, relativePath string) []sharedtypes.ByteRange {
	matched := make([]sharedtypes.ByteRange, 0, len(ranges))
	for i := range ranges {
		if ranges[i].RelativePath != relativePath {
			continue
		}
		matched = append(matched, ranges[i])
	}
	return matched
}
// countSegmentRanges returns the total number of byte ranges across segments.
func countSegmentRanges(segments []sharedtypes.SourceSegmentPlan) int {
	n := 0
	for i := range segments {
		n += len(segments[i].ByteRanges)
	}
	return n
}
// shouldBypassDirectPullFFI reports whether a transfer must skip the direct-pull
// FFI path. It returns true in any of these cases:
//   - chunk-CRC-mismatch injection is enabled;
//   - transferUsesHuggingFaceSource(spec) is true;
//   - the transfer is single-source direct with more than one file;
//   - the transfer is a single-file, single-source-direct ring transfer whose
//     only segment is a "node" endpoint naming the target node itself.
func shouldBypassDirectPullFFI(spec sharedtypes.TransferSpec, files []sharedtypes.ArtifactFile) bool {
	if injectChunkCRCMismatchEnabled() || transferUsesHuggingFaceSource(spec) {
		return true
	}
	if spec.TransferMode != sharedtypes.TransferModeSingleSourceDirect {
		return false
	}
	if len(files) > 1 {
		return true
	}
	if spec.CollectiveSpec.Ring == nil || len(files) != 1 || len(spec.SourceSegments) != 1 {
		return false
	}
	endpoint := spec.SourceSegments[0].SourceEndpoint
	return strings.EqualFold(endpoint.SourceType, "node") &&
		endpoint.NodeName != "" &&
		endpoint.NodeName == targetNodeID(spec)
}
// tryExecuteDirectPullWithFFI runs spec through the Rust direct-pull FFI.
//
// It fails immediately when TCP fallback is forced. Otherwise it normalizes the
// spec's endpoints and holds a directPullSemaphore slot until it returns. An
// empty transport path from the FFI is logged and replaced with TCP fallback.
func (a *Adapter) tryExecuteDirectPullWithFFI(spec sharedtypes.TransferSpec) (sharedtypes.TransferResult, error) {
	if a.forceTCPFallback {
		return sharedtypes.TransferResult{}, fmt.Errorf("tcp fallback forced")
	}
	normalized, normErr := normalizeTransferSpecEndpoints(spec)
	if normErr != nil {
		return sharedtypes.TransferResult{}, normErr
	}

	directPullSemaphore <- struct{}{}
	defer func() { <-directPullSemaphore }()

	result, ffiErr := rdmaffi.ExecuteDirectPull(normalized)
	switch {
	case ffiErr != nil:
		return sharedtypes.TransferResult{}, errutil.Wrap("execute direct pull via rdma ffi", ffiErr)
	case result.TransportPath == "":
		a.logger.Warn("rust direct-pull FFI returned empty transport path, coercing to TCP fallback", "taskID", spec.TaskID, "mode", spec.TransferMode)
		result.TransportPath = sharedtypes.TransportPathTCPFallback
	}
	return result, nil
}
// determineTransportPath reports which transport this adapter will use. It
// returns TCP fallback when fallback is forced, when FFI is disabled, or when
// the chunk client is a LocalChunkClient, *LocalChunkClient or *HTTPChunkClient.
// Otherwise it returns whatever rdmaffi.PreferredTransport reports.
func (a *Adapter) determineTransportPath() sharedtypes.TransportPath {
	tcpOnly := a.forceTCPFallback || !a.enableFFI
	switch a.client.(type) {
	case LocalChunkClient, *LocalChunkClient, *HTTPChunkClient:
		tcpOnly = true
	}
	if tcpOnly {
		return sharedtypes.TransportPathTCPFallback
	}
	return rdmaffi.PreferredTransport()
}