* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package rdma
import (
"context"
"fmt"
"log/slog"
"net/http"
"os"
"strings"
"sync/atomic"
"time"
"github.com/openfuyao/weight-dispatcher/pkg/dataplane"
sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)
// Compile-time assertion that *Adapter satisfies the dataplane.Adapter interface.
var _ dataplane.Adapter = (*Adapter)(nil)
// directPullSemaphore is a buffered channel with capacity 3.
// NOTE(review): presumably used as a counting semaphore that bounds concurrent
// direct pulls; its users are outside this section — confirm against callers.
var directPullSemaphore = make(chan struct{}, 3)
// Symmetric-fanout tuning values and the peer source-type label.
// NOTE(review): none of these are referenced in this section; the descriptions
// below are derived from the identifiers only — confirm against their users.
const (
// defaultSymmetricFanoutSourceParallelism is presumably the per-source
// parallelism used when a symmetric fanout does not specify one.
defaultSymmetricFanoutSourceParallelism int32 = 3
// maxSymmetricFanoutSourceParallelism is presumably the upper bound applied
// to the symmetric fanout per-source parallelism.
maxSymmetricFanoutSourceParallelism int32 = 4
// sourceTypePeer is the string label "peer" for a source type.
sourceTypePeer = "peer"
)
// routedTransferResult is what tryExecuteRoutedTransfer returns: the transfer
// result plus a flag telling Execute whether a routed path took over the
// transfer.
type routedTransferResult struct {
// result is the transfer result produced by the routed path.
result sharedtypes.TransferResult
// handled is true when a routed path ran (successfully or not); Execute then
// returns immediately instead of running the default stages.
handled bool
}
// transferTiming bundles the per-stage timings that Execute passes to
// logTransferComplete.
type transferTiming struct {
// start is when Execute began; used to compute the total execute time.
start time.Time
// resolveMs is the time spent resolving the file list, in milliseconds.
resolveMs int64
// initialTransferMs is the elapsed time of the initial transfer stage, in milliseconds.
initialTransferMs int64
// collectiveMs is the elapsed time of the collective stage, in milliseconds.
collectiveMs int64
// usedFFI reports whether the initial transfer was served by the FFI path.
usedFFI bool
}
// initialTransferStageResult is the outcome of executeInitialTransferStage.
type initialTransferStageResult struct {
// result is the transfer result of the initial transfer.
result sharedtypes.TransferResult
// elapsedMs is the stage duration in milliseconds, including the time spent
// opening the files for the collective stage.
elapsedMs int64
// usedFFI reports whether the initial transfer was served by the FFI path.
usedFFI bool
// openFiles holds the files opened by prepareCollectiveOpenFiles; Execute
// closes them once it returns.
openFiles map[string]*os.File
}
// jobExecutionPhaseResult is the outcome of executeGoTransferJobs.
type jobExecutionPhaseResult struct {
// result is the transfer result returned by executeJobs.
result sharedtypes.TransferResult
// buildJobsMs is the time spent in buildTransferJobs, in milliseconds.
buildJobsMs int64
// executeJobsMs is the time spent in executeJobs, in milliseconds.
executeJobsMs int64
}
// preparedSubSpecTargets is the outcome of prepareSubSpecTargets.
type preparedSubSpecTargets struct {
// openFiles holds the opened target files; the caller is responsible for
// closing them (executeTransferSubSpec does so via closeOpenFiles).
openFiles map[string]*os.File
// prepareMs is the time spent opening the target files, in milliseconds.
prepareMs int64
}
// forceTCPFallbackDataPlane reports whether the WD_FORCE_TCP_FALLBACK
// environment variable holds a truthy value. Surrounding whitespace is ignored
// and the comparison is case-insensitive; "1", "true", "yes" and "on" are
// truthy, anything else (including unset or empty) is false.
func forceTCPFallbackDataPlane() bool {
	value := strings.ToLower(strings.TrimSpace(os.Getenv("WD_FORCE_TCP_FALLBACK")))
	return value == "1" || value == "true" || value == "yes" || value == "on"
}
// injectedChunkCRCMismatch records whether the one-shot CRC-mismatch fault
// injection has already fired. maybeInjectChunkCRCMismatch flips it from false
// to true with CompareAndSwap; nothing in this section resets it.
var injectedChunkCRCMismatch atomic.Bool
// injectChunkCRCMismatchEnabled reports whether the
// WD_INJECT_CHUNK_CRC_MISMATCH_ONCE environment variable holds a truthy value
// ("1", "true", "yes" or "on", case-insensitive, surrounding whitespace
// ignored). Unset, empty or any other value yields false.
func injectChunkCRCMismatchEnabled() bool {
	setting := strings.ToLower(strings.TrimSpace(os.Getenv("WD_INJECT_CHUNK_CRC_MISMATCH_ONCE")))
	for _, truthy := range [...]string{"1", "true", "yes", "on"} {
		if setting == truthy {
			return true
		}
	}
	return false
}
// maybeInjectChunkCRCMismatch flips the lowest bit of payload[0] in place, at
// most once per process, when fault injection is enabled through the
// environment and payload is non-empty. The one-shot guarantee comes from the
// CompareAndSwap on injectedChunkCRCMismatch: only the first caller that wins
// the swap mutates its payload.
func maybeInjectChunkCRCMismatch(payload []byte) {
	if injectChunkCRCMismatchEnabled() && len(payload) > 0 &&
		injectedChunkCRCMismatch.CompareAndSwap(false, true) {
		payload[0] ^= 0x1
	}
}
// effectiveChunkRetryLimit returns spec.RetryLimit as an int when it is
// positive, and 1 otherwise.
func effectiveChunkRetryLimit(spec sharedtypes.TransferSpec) int {
	if spec.RetryLimit <= 0 {
		return 1
	}
	return int(spec.RetryLimit)
}
// Adapter is the data-plane adapter defined in this package; it is asserted to
// implement dataplane.Adapter. Construct it with NewAdapter, NewAdapterWithFFI
// or NewAdapterWithOptions so that client, httpClient and logger are never nil.
type Adapter struct {
// client is the chunk client supplied by the caller (LocalChunkClient{} when
// the caller passes nil).
client ChunkClient
// httpClient is the HTTP client shared with an *HTTPChunkClient when one is
// supplied, otherwise a fresh client from newDataPlaneHTTPClient.
httpClient *http.Client
// logger receives all structured log output of the adapter.
logger *slog.Logger
// enableFFI allows the FFI direct-pull path for the initial transfer.
enableFFI bool
// forceTCPFallback disables the FFI paths; set from AdapterOptions or the
// WD_FORCE_TCP_FALLBACK environment variable.
forceTCPFallback bool
}
// AdapterOptions carries the optional switches accepted by
// NewAdapterWithOptions.
type AdapterOptions struct {
// EnableFFI is copied into Adapter.enableFFI.
EnableFFI bool
// ForceTCPFallback forces the TCP fallback; NewAdapterWithOptions also turns
// the fallback on when WD_FORCE_TCP_FALLBACK is truthy, even if this is false.
ForceTCPFallback bool
}
// transferUsesHuggingFaceSource reports whether at least one source segment of
// spec points at an endpoint that IsHuggingFaceEndpoint recognizes. A spec
// without source segments yields false.
func transferUsesHuggingFaceSource(spec sharedtypes.TransferSpec) bool {
	for i := range spec.SourceSegments {
		endpoint := spec.SourceSegments[i].SourceEndpoint
		if IsHuggingFaceEndpoint(endpoint.SourceType, endpoint.Endpoint) {
			return true
		}
	}
	return false
}
// Relay tuning constants.
// NOTE(review): neither constant is referenced in this section; the
// descriptions are derived from the identifiers only — confirm against users.
const (
// relayFanoutBatchTargetBytes is 512 MiB; presumably the target batch size
// for relay fanout.
relayFanoutBatchTargetBytes int64 = 512 * 1024 * 1024
// relayPeerMaxInflightIterations is presumably the cap on concurrently
// in-flight relay peer iterations.
relayPeerMaxInflightIterations = 4
)
// NewAdapter builds an Adapter with FFI enabled. The TCP fallback is forced
// only when the WD_FORCE_TCP_FALLBACK environment variable is truthy. A nil
// client or logger is replaced with a default by NewAdapterWithOptions.
func NewAdapter(client ChunkClient, logger *slog.Logger) *Adapter {
	options := AdapterOptions{EnableFFI: true}
	options.ForceTCPFallback = forceTCPFallbackDataPlane()
	return NewAdapterWithOptions(client, logger, options)
}
// NewAdapterWithFFI builds an Adapter whose FFI usage is controlled by
// enableFFI. The TCP fallback is forced only when the WD_FORCE_TCP_FALLBACK
// environment variable is truthy. A nil client or logger is replaced with a
// default by NewAdapterWithOptions.
func NewAdapterWithFFI(client ChunkClient, logger *slog.Logger, enableFFI bool) *Adapter {
	var options AdapterOptions
	options.EnableFFI = enableFFI
	options.ForceTCPFallback = forceTCPFallbackDataPlane()
	return NewAdapterWithOptions(client, logger, options)
}
// NewAdapterWithOptions builds an Adapter from explicit options.
//
// A nil client is replaced with LocalChunkClient{} and a nil logger with
// slog.Default(). The adapter's HTTP client comes from newDataPlaneHTTPClient
// unless client is an *HTTPChunkClient carrying its own non-nil HTTP client, in
// which case that one is shared. The TCP fallback is forced when either
// options.ForceTCPFallback or the WD_FORCE_TCP_FALLBACK environment variable
// asks for it.
func NewAdapterWithOptions(client ChunkClient, logger *slog.Logger, options AdapterOptions) *Adapter {
	if client == nil {
		client = LocalChunkClient{}
	}
	if logger == nil {
		logger = slog.Default()
	}

	adapter := &Adapter{
		client:           client,
		httpClient:       newDataPlaneHTTPClient(),
		logger:           logger,
		enableFFI:        options.EnableFFI,
		forceTCPFallback: options.ForceTCPFallback || forceTCPFallbackDataPlane(),
	}
	// Share the chunk client's HTTP client when it already has one.
	if typed, isHTTP := client.(*HTTPChunkClient); isHTTP && typed.client != nil {
		adapter.httpClient = typed.client
	}
	return adapter
}
// Execute runs the transfer described by spec and returns its finalized result.
//
// The steps are, in order: resolve the file list, log the start, and offer the
// transfer to the routed paths (tryExecuteRoutedTransfer). When a routed path
// handles it, that path's result and error are returned unchanged. Otherwise
// the initial transfer stage runs, followed by the collective stage, and the
// combined result is finalized and logged. On any error in the default path a
// zero TransferResult is returned together with the error.
func (a *Adapter) Execute(ctx context.Context, spec sharedtypes.TransferSpec) (sharedtypes.TransferResult, error) {
	var none sharedtypes.TransferResult
	began := time.Now()

	files, resolveMs, err := a.resolveExecuteFiles(ctx, spec)
	if err != nil {
		return none, err
	}

	workers := normalizedParallelism(spec)
	a.logTransferStart(spec, files, workers, resolveMs)

	routed, routedErr := a.tryExecuteRoutedTransfer(ctx, began, spec, files, resolveMs)
	if routed.handled {
		return routed.result, routedErr
	}

	stage, err := a.executeInitialTransferStage(ctx, spec, files, workers)
	if err != nil {
		return none, err
	}
	// Files opened for the collective stage stay open until Execute returns.
	if stage.openFiles != nil {
		defer closeOpenFiles(stage.openFiles)
	}

	combined, collectiveMs, path, err := a.executeCollectiveStage(ctx, spec, stage.openFiles, stage.result)
	if err != nil {
		return none, err
	}

	final := finalizeTransferResult(began, spec, files, combined, path)
	timing := transferTiming{
		start:             began,
		resolveMs:         resolveMs,
		initialTransferMs: stage.elapsedMs,
		collectiveMs:      collectiveMs,
		usedFFI:           stage.usedFFI,
	}
	a.logTransferComplete(spec, final, timing)
	return final, nil
}
// resolveExecuteFiles resolves the file list for spec via resolveFiles and
// returns it together with the elapsed resolution time in milliseconds. On
// failure the error is logged and returned with a nil list and zero duration.
func (a *Adapter) resolveExecuteFiles(ctx context.Context, spec sharedtypes.TransferSpec) ([]sharedtypes.ArtifactFile, int64, error) {
	began := time.Now()
	resolved, err := resolveFiles(ctx, spec)
	if err == nil {
		return resolved, time.Since(began).Milliseconds(), nil
	}
	a.logger.Error("resolve transfer files failed", "taskID", spec.TaskID, "mode", spec.TransferMode, "error", err)
	return nil, 0, err
}
// normalizedParallelism returns spec.Parallelism as an int when it is positive,
// and the default of 4 otherwise.
func normalizedParallelism(spec sharedtypes.TransferSpec) int {
	if spec.Parallelism <= 0 {
		return 4
	}
	return int(spec.Parallelism)
}
// logTransferStart emits the Info-level "data-plane transfer started" record
// with the task identity, the classified scenario, the effective parallelism,
// the source and file counts, the target path and the file-resolution time.
func (a *Adapter) logTransferStart(spec sharedtypes.TransferSpec, files []sharedtypes.ArtifactFile, parallelism int, resolveMs int64) {
	attrs := []any{
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"scenario", classifyTransferScenario(spec),
		"parallelism", parallelism,
		"sourceCount", len(spec.SourceSegments),
		"targetPath", spec.TargetTempPath,
		"fileCount", len(files),
		"resolveFilesMs", resolveMs,
	}
	a.logger.Info("data-plane transfer started", attrs...)
}
// tryExecuteRoutedTransfer gives the specialized execution paths the chance to
// take over the whole transfer.
//
// The directory-per-file path is tried first (when shouldExecuteDirectoryPerFile
// says so), then the concurrent symmetric fanout. The returned value has
// handled == true whenever one of those paths ran, even if it failed; in that
// case the caller must return the accompanying result and error as-is. When no
// path applies, a zero value with handled == false and a nil error is returned.
func (a *Adapter) tryExecuteRoutedTransfer(
	ctx context.Context,
	start time.Time,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	resolveMs int64,
) (routedTransferResult, error) {
	if shouldExecuteDirectoryPerFile(spec, files) {
		perFileResult, perFileErr := a.executeDirectoryPerFile(ctx, start, spec, files, resolveMs)
		return routedTransferResult{result: perFileResult, handled: true}, perFileErr
	}

	fanoutResult, fanoutPath, handled, fanoutErr := a.executeConcurrentSymmetricFanout(ctx, spec, files)
	if !handled {
		return routedTransferResult{}, nil
	}
	if fanoutErr != nil {
		return routedTransferResult{handled: true}, fanoutErr
	}

	a.logger.Info(
		"data-plane fanout execute completed",
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"scenario", classifyTransferScenario(spec),
		"handledBy", "executeConcurrentSymmetricFanout",
		"elapsedMs", time.Since(start).Milliseconds(),
		"bytesTransferred", fanoutResult.BytesTransferred,
		"transportPath", fanoutPath,
	)
	finalized := finalizeTransferResult(start, spec, files, fanoutResult, fanoutPath)
	return routedTransferResult{result: finalized, handled: true}, nil
}
// executeInitialTransferStage runs the initial transfer and then opens the
// files the collective stage needs (prepareCollectiveOpenFiles).
//
// The returned elapsedMs covers both steps. On error from either step a zero
// value is returned with the error; on success the caller owns openFiles and
// must close them.
func (a *Adapter) executeInitialTransferStage(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	parallelism int,
) (initialTransferStageResult, error) {
	stageBegan := time.Now()

	transferResult, viaFFI, err := a.executeInitialTransfer(ctx, spec, files, parallelism)
	if err != nil {
		return initialTransferStageResult{}, err
	}

	collectiveFiles, err := a.prepareCollectiveOpenFiles(spec, files)
	if err != nil {
		return initialTransferStageResult{}, err
	}

	stage := initialTransferStageResult{
		result:    transferResult,
		elapsedMs: time.Since(stageBegan).Milliseconds(),
		usedFFI:   viaFFI,
		openFiles: collectiveFiles,
	}
	return stage, nil
}
// prepareCollectiveOpenFiles opens the already-prepared target files for the
// collective stage. For any transfer mode other than partial-pull all-gather it
// opens nothing and returns an empty, non-nil map. An open failure is logged
// and returned with a nil map.
func (a *Adapter) prepareCollectiveOpenFiles(
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
) (map[string]*os.File, error) {
	if spec.TransferMode != sharedtypes.TransferModePartialPullAllGather {
		return make(map[string]*os.File), nil
	}

	opened, openErr := openPreparedFiles(spec.TargetTempPath, files)
	if openErr != nil {
		a.logger.Error("open prepared collective files failed", "taskID", spec.TaskID, "targetPath", spec.TargetTempPath, "error", openErr)
		return nil, openErr
	}
	return opened, nil
}
// executeCollectiveStage runs the collective (partial-pull all-gather) step on
// top of the initial transfer result.
//
// The transport path starts from initialResult.TransportPath, falling back to
// determineTransportPath when that is empty. For any other transfer mode the
// initial result is returned untouched with a collective time of 0. Otherwise
// the all-gather runs, and the merged result, the elapsed milliseconds and the
// merged transport path are returned. On failure the error is logged and
// returned with zero values.
func (a *Adapter) executeCollectiveStage(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	openFiles map[string]*os.File,
	initialResult sharedtypes.TransferResult,
) (sharedtypes.TransferResult, int64, sharedtypes.TransportPath, error) {
	basePath := initialResult.TransportPath
	if basePath == "" {
		basePath = a.determineTransportPath()
	}
	if spec.TransferMode != sharedtypes.TransferModePartialPullAllGather {
		return initialResult, 0, basePath, nil
	}

	began := time.Now()
	gathered, gatherPath, err := a.executePartialPullAllGather(ctx, spec, openFiles)
	if err != nil {
		a.logger.Error("execute partial pull allgather failed", "taskID", spec.TaskID, "transportPath", basePath, "error", err)
		return sharedtypes.TransferResult{}, 0, "", err
	}

	merged := mergeTransferResults(initialResult, gathered)
	elapsedMs := time.Since(began).Milliseconds()
	return merged, elapsedMs, mergeTransportPath(basePath, gatherPath), nil
}
// logTransferComplete emits the Info-level "data-plane transfer completed"
// record: the task identity and scenario, the result's transport path, byte and
// chunk counts and throughput, the per-stage timings from timing, and the total
// time elapsed since timing.start.
func (a *Adapter) logTransferComplete(
	spec sharedtypes.TransferSpec,
	result sharedtypes.TransferResult,
	timing transferTiming,
) {
	attrs := []any{
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"scenario", classifyTransferScenario(spec),
		"transportPath", result.TransportPath,
		"bytes", result.BytesTransferred,
		"chunks", result.ChunkCount,
		"throughputMBps", result.ThroughputMBps,
		"resolveFilesMs", timing.resolveMs,
		"initialTransferMs", timing.initialTransferMs,
		"usedFFI", timing.usedFFI,
		"collectiveMs", timing.collectiveMs,
		"totalExecuteMs", time.Since(timing.start).Milliseconds(),
	}
	a.logger.Info("data-plane transfer completed", attrs...)
}
// finalizeTransferResult stamps result with the task metadata and attaches a
// solidified manifest.
//
// The transferred chunks are de-duplicated with dedupeChunks; the manifest
// lists those chunks plus one FileDigest per entry of files, whose ChunkCount
// is the number of de-duplicated chunks with the same relative path. The
// result's task ID, temp path, start/finish timestamps (Unix milliseconds),
// transport path and chunk list are overwritten. ThroughputMBps is set to
// MiB per second over the time since start, and left unchanged when that
// duration is not positive.
func finalizeTransferResult(
	start time.Time,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	result sharedtypes.TransferResult,
	transportPath sharedtypes.TransportPath,
) sharedtypes.TransferResult {
	chunks := dedupeChunks(result.TransferredChunks)

	// Count the unique chunks per file.
	chunksPerFile := make(map[string]int32, len(files))
	for i := range chunks {
		chunksPerFile[chunks[i].RelativePath]++
	}

	manifest := &sharedtypes.SolidifiedManifest{
		ArtifactKey:    spec.ArtifactKey,
		LogicalDigest:  spec.LogicalManifest.Digest,
		ChunkSizeBytes: spec.ChunkSizeBytes,
		GeneratedAt:    time.Now().UnixMilli(),
		Chunks:         chunks,
	}
	for i := range files {
		relativePath := files[i].RelativePath
		manifest.Files = append(manifest.Files, sharedtypes.FileDigest{
			RelativePath: relativePath,
			SizeBytes:    files[i].SizeBytes,
			ChunkCount:   chunksPerFile[relativePath],
		})
	}

	elapsed := time.Since(start)
	result.TaskID = spec.TaskID
	result.TempPath = spec.TargetTempPath
	result.StartedAt = start.UnixMilli()
	result.FinishedAt = time.Now().UnixMilli()
	result.TransportPath = transportPath
	result.TransferredChunks = chunks
	result.SolidifiedManifest = manifest
	if elapsed > 0 {
		mebibytes := float64(result.BytesTransferred) / 1024 / 1024
		result.ThroughputMBps = mebibytes / elapsed.Seconds()
	}
	return result
}
// ensurePreservedTargetsPrepared runs prepareTargetFilesPreserve for files
// under targetTempPath and immediately closes the handles it returns; only the
// preparation itself is wanted here. The preparation error, if any, is returned.
func ensurePreservedTargetsPrepared(targetTempPath string, files []sharedtypes.ArtifactFile) error {
	opened, err := prepareTargetFilesPreserve(targetTempPath, files)
	if err != nil {
		return err
	}
	closeOpenFiles(opened)
	return nil
}
// openTransferTargetFiles opens the target files for a transfer, delegating to
// prepareTargetFilesPreserve when preserveExisting is set and to
// prepareTargetFiles otherwise.
func openTransferTargetFiles(targetTempPath string, files []sharedtypes.ArtifactFile, preserveExisting bool) (map[string]*os.File, error) {
	if !preserveExisting {
		return prepareTargetFiles(targetTempPath, files)
	}
	return prepareTargetFilesPreserve(targetTempPath, files)
}
// openSubSpecTargetFiles opens the target files for a sub-spec transfer,
// delegating to openPreparedFiles when preserveExisting is set and to
// prepareTargetFiles otherwise.
func openSubSpecTargetFiles(targetTempPath string, files []sharedtypes.ArtifactFile, preserveExisting bool) (map[string]*os.File, error) {
	switch {
	case preserveExisting:
		return openPreparedFiles(targetTempPath, files)
	default:
		return prepareTargetFiles(targetTempPath, files)
	}
}
// tryExecuteInitialTransferFFI attempts the initial transfer through the FFI
// direct-pull path.
//
// The attempt is skipped (and logged at Info) when the TCP fallback is forced,
// FFI is disabled, or shouldBypassDirectPullFFI says the spec must not use it.
// The boolean result is true only when the FFI call succeeded, in which case
// the FFI result is returned. An FFI failure is logged as a warning and
// reported as "not used" so that the caller falls back to the Go chunk client.
// initialStarted is only used to compute the elapsed times in the logs.
func (a *Adapter) tryExecuteInitialTransferFFI(
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	initialStarted time.Time,
) (sharedtypes.TransferResult, bool) {
	var none sharedtypes.TransferResult

	skip := a.forceTCPFallback || !a.enableFFI || shouldBypassDirectPullFFI(spec, files)
	if skip {
		a.logger.Info(
			"skipping rust direct-pull FFI",
			"taskID", spec.TaskID,
			"mode", spec.TransferMode,
			"fileCount", len(files),
			"forceTCPFallback", a.forceTCPFallback,
		)
		return none, false
	}

	pulled, pullErr := a.tryExecuteDirectPullWithFFI(spec)
	if pullErr != nil {
		a.logger.Warn(
			"rust direct-pull FFI unavailable, falling back to Go chunk client",
			"taskID", spec.TaskID,
			"mode", spec.TransferMode,
			"reason", pullErr,
			"ffiAttemptMs", time.Since(initialStarted).Milliseconds(),
		)
		return none, false
	}

	a.logger.Info(
		"rust direct-pull FFI completed",
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"transportPath", pulled.TransportPath,
		"bytes", pulled.BytesTransferred,
		"chunks", pulled.ChunkCount,
		"elapsedMs", time.Since(initialStarted).Milliseconds(),
	)
	return pulled, true
}
// tryExecuteSubSpecWithFFI attempts a sub-spec transfer through the FFI
// direct-pull path.
//
// The attempt is skipped when the TCP fallback is forced, when FFI is disabled
// on this adapter, or when any source segment is a Hugging Face endpoint. The
// FFI-disabled check mirrors tryExecuteInitialTransferFFI so that an adapter
// built with FFI turned off never calls into FFI.
//
// The boolean result is true only when the FFI call succeeded, in which case
// the FFI result is returned. An FFI failure is logged at Debug level and
// reported as "not used" so that the caller falls back to the Go chunk client.
// started is only used to compute the elapsed times in the logs.
func (a *Adapter) tryExecuteSubSpecWithFFI(
	spec sharedtypes.TransferSpec,
	started time.Time,
) (sharedtypes.TransferResult, bool) {
	if a.forceTCPFallback || !a.enableFFI {
		return sharedtypes.TransferResult{}, false
	}
	if transferUsesHuggingFaceSource(spec) {
		return sharedtypes.TransferResult{}, false
	}

	ffiResult, ffiErr := a.tryExecuteDirectPullWithFFI(spec)
	if ffiErr != nil {
		// Record why the fallback happens instead of dropping the error.
		a.logger.Debug(
			"transfer sub-spec FFI unavailable, falling back to Go chunk client",
			"taskID", spec.TaskID,
			"mode", spec.TransferMode,
			"reason", ffiErr,
			"ffiAttemptMs", time.Since(started).Milliseconds(),
		)
		return sharedtypes.TransferResult{}, false
	}

	a.logger.Info(
		"transfer sub-spec completed via FFI",
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"sourceCount", len(spec.SourceSegments),
		"bytes", ffiResult.BytesTransferred,
		"transportPath", ffiResult.TransportPath,
		"elapsedMs", time.Since(started).Milliseconds(),
	)
	return ffiResult, true
}
// executeGoTransferJobs builds the transfer jobs for spec and runs them through
// executeJobs against the already opened target files.
//
// The result carries the executeJobs outcome plus the time spent building and
// executing the jobs, both in milliseconds. A failure in either phase is logged
// and returned with a zero value. The retry limit passed to executeJobs comes
// from effectiveChunkRetryLimit.
func (a *Adapter) executeGoTransferJobs(
	ctx context.Context,
	spec sharedtypes.TransferSpec,
	files []sharedtypes.ArtifactFile,
	parallelism int,
	openFiles map[string]*os.File,
) (jobExecutionPhaseResult, error) {
	buildBegan := time.Now()
	jobs, buildErr := buildTransferJobs(spec.TransferMode, spec.SourceSegments, files, spec.ChunkSizeBytes)
	if buildErr != nil {
		a.logger.Error("build transfer jobs failed", "taskID", spec.TaskID, "mode", spec.TransferMode, "error", buildErr)
		return jobExecutionPhaseResult{}, buildErr
	}
	buildMs := time.Since(buildBegan).Milliseconds()

	runBegan := time.Now()
	transferred, runErr := a.executeJobs(ctx, spec.TargetTempPath, jobs, openFiles, parallelism, effectiveChunkRetryLimit(spec))
	if runErr != nil {
		a.logger.Error("execute transfer jobs failed", "taskID", spec.TaskID, "mode", spec.TransferMode, "jobCount", len(jobs), "parallelism", parallelism, "error", runErr)
		return jobExecutionPhaseResult{}, runErr
	}

	phase := jobExecutionPhaseResult{
		result:        transferred,
		buildJobsMs:   buildMs,
		executeJobsMs: time.Since(runBegan).Milliseconds(),
	}
	return phase, nil
}
// resolveSubSpecFiles resolves the file list for a sub-spec via resolveFiles
// and returns it together with the elapsed resolution time in milliseconds. On
// failure the error is logged and returned with a nil list and zero duration.
func (a *Adapter) resolveSubSpecFiles(ctx context.Context, spec sharedtypes.TransferSpec) ([]sharedtypes.ArtifactFile, int64, error) {
	began := time.Now()
	resolved, err := resolveFiles(ctx, spec)
	if err == nil {
		return resolved, time.Since(began).Milliseconds(), nil
	}
	a.logger.Error("resolve sub-spec files failed", "taskID", spec.TaskID, "mode", spec.TransferMode, "error", err)
	return nil, 0, err
}
// prepareSubSpecTargets opens the target files for a sub-spec transfer.
//
// When spec.PreserveExisting is set, the targets are first prepared in
// preserve mode (ensurePreservedTargetsPrepared). The files are then opened
// with openSubSpecTargetFiles; prepareMs measures only that opening step.
// Failures are logged and returned with a zero value. On success the caller
// owns the returned open files.
func (a *Adapter) prepareSubSpecTargets(spec sharedtypes.TransferSpec, files []sharedtypes.ArtifactFile) (preparedSubSpecTargets, error) {
	if spec.PreserveExisting {
		prepErr := ensurePreservedTargetsPrepared(spec.TargetTempPath, files)
		if prepErr != nil {
			a.logger.Error("prepare preserved targets before sub-spec transfer failed", "taskID", spec.TaskID, "targetPath", spec.TargetTempPath, "error", prepErr)
			return preparedSubSpecTargets{}, prepErr
		}
	}

	openBegan := time.Now()
	opened, openErr := openSubSpecTargetFiles(spec.TargetTempPath, files, spec.PreserveExisting)
	if openErr != nil {
		a.logger.Error("open sub-spec target files failed", "taskID", spec.TaskID, "targetPath", spec.TargetTempPath, "preserveExisting", spec.PreserveExisting, "error", openErr)
		return preparedSubSpecTargets{}, openErr
	}

	targets := preparedSubSpecTargets{
		openFiles: opened,
		prepareMs: time.Since(openBegan).Milliseconds(),
	}
	return targets, nil
}
// executeInitialTransfer performs the initial data transfer for spec.
//
// When spec.PreserveExisting is set, the targets are prepared in preserve mode
// first. The FFI direct-pull path is then tried; if it succeeds its result is
// returned with the boolean set to true. Otherwise the target files are opened
// (and closed again when this function returns), the jobs run through the Go
// chunk client, and the result is returned with the boolean set to false and
// TransportPath taken from determineTransportPath. On error a zero result and
// false are returned with the error.
func (a *Adapter) executeInitialTransfer(ctx context.Context, spec sharedtypes.TransferSpec, files []sharedtypes.ArtifactFile, parallelism int) (sharedtypes.TransferResult, bool, error) {
	var none sharedtypes.TransferResult
	began := time.Now()

	if spec.PreserveExisting {
		if prepErr := ensurePreservedTargetsPrepared(spec.TargetTempPath, files); prepErr != nil {
			a.logger.Error("prepare preserved targets before initial transfer failed", "taskID", spec.TaskID, "targetPath", spec.TargetTempPath, "error", prepErr)
			return none, false, prepErr
		}
	}

	ffiResult, viaFFI := a.tryExecuteInitialTransferFFI(spec, files, began)
	if viaFFI {
		return ffiResult, true, nil
	}

	// Go fallback: open the targets, run the chunk jobs, then close the targets.
	openBegan := time.Now()
	targets, err := openTransferTargetFiles(spec.TargetTempPath, files, spec.PreserveExisting)
	if err != nil {
		a.logger.Error("open initial transfer target files failed", "taskID", spec.TaskID, "targetPath", spec.TargetTempPath, "preserveExisting", spec.PreserveExisting, "error", err)
		return none, false, err
	}
	defer closeOpenFiles(targets)
	prepareTargetMs := time.Since(openBegan).Milliseconds()

	phase, err := a.executeGoTransferJobs(ctx, spec, files, parallelism, targets)
	if err != nil {
		return none, false, err
	}

	goResult := phase.result
	goResult.TransportPath = a.determineTransportPath()
	a.logger.Info(
		"go initial transfer completed",
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"parallelism", parallelism,
		"prepareTargetMs", prepareTargetMs,
		"buildJobsMs", phase.buildJobsMs,
		"executeJobsMs", phase.executeJobsMs,
		"elapsedMs", time.Since(began).Milliseconds(),
		"bytes", goResult.BytesTransferred,
		"transportPath", goResult.TransportPath,
	)
	return goResult, false, nil
}
// executePartialPullAllGather runs the collective step of a partial-pull
// all-gather transfer.
//
// It requires ring collective metadata (mode ring and a non-nil Ring) and
// returns an error otherwise. The strategies are tried in this order:
//  1. relay peer fetch — used whenever executeRelayPeerFetch reports that it
//     handled the transfer, and its result and error are returned as-is;
//  2. owner fetch with direct pull — used when it succeeds;
//  3. owner-fetch fallback — used when either ring endpoint is empty;
//  4. ring collective over the API — used otherwise.
func (a *Adapter) executePartialPullAllGather(ctx context.Context, spec sharedtypes.TransferSpec, openFiles map[string]*os.File) (sharedtypes.TransferResult, sharedtypes.TransportPath, error) {
	if spec.CollectiveSpec.Mode != sharedtypes.CollectiveModeRing || spec.CollectiveSpec.Ring == nil {
		return sharedtypes.TransferResult{}, "", fmt.Errorf("partial pull allgather requires ring collective metadata")
	}

	relayResult, relayPath, relayHandled, relayErr := a.executeRelayPeerFetch(ctx, spec)
	if relayHandled {
		return relayResult, relayPath, relayErr
	}

	ring := spec.CollectiveSpec.Ring
	a.logger.Debug(
		"evaluating partial pull allgather plan",
		"taskID", spec.TaskID,
		"selfNode", ring.SelfNode,
		"prevNode", ring.PrevNode,
		"nextNode", ring.NextNode,
		"prevEndpoint", ring.PrevEndpoint,
		"nextEndpoint", ring.NextEndpoint,
		"peerCount", len(spec.CollectiveSpec.Peers),
		"sourceSegmentCount", len(spec.SourceSegments),
	)

	ownerResult, ownerPath, ownerErr := a.executeOwnerFetchWithDirectPull(ctx, spec)
	if ownerErr == nil {
		return ownerResult, ownerPath, nil
	}
	a.logger.Warn("owner-fetch direct pull unavailable for partial pull allgather, falling back to legacy collective path", "taskID", spec.TaskID, "reason", ownerErr)

	// The ring exchange needs both neighbours; without them use owner fetch.
	ringEndpointsKnown := ring.PrevEndpoint != "" && ring.NextEndpoint != ""
	if !ringEndpointsKnown {
		a.logger.Warn(
			"ring endpoints missing for partial pull allgather, using owner-fetch fallback",
			"taskID", spec.TaskID,
			"prevEndpoint", ring.PrevEndpoint,
			"nextEndpoint", ring.NextEndpoint,
		)
		return a.executeOwnerFetchFallback(ctx, spec, openFiles)
	}
	return a.executeRingCollectiveOverAPI(ctx, spec, openFiles)
}
// executeTransferSubSpec executes one sub-spec transfer.
//
// The file list is resolved first. The FFI path is then tried and, when it
// succeeds, its result is returned directly. Otherwise the target files are
// prepared and opened (and closed again when this function returns), the jobs
// run through the Go chunk client with a parallelism of at least 1, and the
// result is returned with TransportPath taken from determineTransportPath. On
// error a zero result is returned with the error.
func (a *Adapter) executeTransferSubSpec(ctx context.Context, spec sharedtypes.TransferSpec) (sharedtypes.TransferResult, error) {
	var none sharedtypes.TransferResult
	began := time.Now()

	files, resolveMs, err := a.resolveSubSpecFiles(ctx, spec)
	if err != nil {
		return none, err
	}

	ffiResult, viaFFI := a.tryExecuteSubSpecWithFFI(spec, began)
	if viaFFI {
		return ffiResult, nil
	}

	prepared, err := a.prepareSubSpecTargets(spec, files)
	if err != nil {
		return none, err
	}
	defer closeOpenFiles(prepared.openFiles)

	workers := max(1, int(spec.Parallelism))
	phase, err := a.executeGoTransferJobs(ctx, spec, files, workers, prepared.openFiles)
	if err != nil {
		return none, err
	}

	goResult := phase.result
	goResult.TransportPath = a.determineTransportPath()
	a.logger.Info(
		"transfer sub-spec completed via Go fallback",
		"taskID", spec.TaskID,
		"mode", spec.TransferMode,
		"sourceCount", len(spec.SourceSegments),
		"resolveFilesMs", resolveMs,
		"prepareTargetMs", prepared.prepareMs,
		"buildJobsMs", phase.buildJobsMs,
		"executeJobsMs", phase.executeJobsMs,
		"elapsedMs", time.Since(began).Milliseconds(),
		"bytes", goResult.BytesTransferred,
		"transportPath", goResult.TransportPath,
	)
	return goResult, nil
}