* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::cell::UnsafeCell;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::fs::{self, File, OpenOptions};
use std::os::unix::fs::FileExt;
use std::path::{Component, Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::mpsc::{self, Receiver, SyncSender};
use std::sync::{Arc, Mutex, OnceLock};
use std::thread::JoinHandle;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine;
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use crate::transport::rdma::{QueuePairConfig, RdmaEndpointInfo, RdmaError, RdmaTransport, RegisteredMemory};
/// Settings used by [`DirectExecutor::prepare`] to open an RDMA transport and
/// size the scratch region it registers.
#[derive(Debug, Clone)]
pub struct DirectPullRequest {
    /// Device to open; `None` is passed through to `RdmaTransport::open` unchanged.
    pub device_name: Option<String>,
    /// Forwarded to `QueuePairConfig::completion_entries`.
    pub completion_entries: i32,
    /// Forwarded to both `send_wr` and `recv_wr` of the queue-pair config.
    pub queue_depth: u32,
    /// Forwarded to `QueuePairConfig::inline_bytes`.
    pub inline_bytes: u32,
    /// Length of the zeroed scratch buffer registered during preparation.
    pub registration_bytes: usize,
}

impl Default for DirectPullRequest {
    /// Any device, 256 completion entries, depth 128, no inline bytes, 4 MiB scratch.
    fn default() -> Self {
        const DEFAULT_REGISTRATION_BYTES: usize = 4 << 20;
        DirectPullRequest {
            registration_bytes: DEFAULT_REGISTRATION_BYTES,
            inline_bytes: 0,
            queue_depth: 128,
            completion_entries: 256,
            device_name: None,
        }
    }
}
/// Outcome of probing an RDMA device for direct pulls.
#[derive(Debug, Clone)]
pub struct DirectExecutor {
    /// Device name reported by the opened transport.
    pub device_name: String,
    /// Size of the scratch region registered during the probe.
    pub registered_bytes: usize,
}

impl DirectExecutor {
    /// Opens an RDMA transport using the queue-pair settings in `request`.
    /// It then registers a zeroed scratch buffer of `request.registration_bytes`
    /// and records the transport's device name.
    ///
    /// The registration handle, the scratch buffer and the transport are all
    /// dropped when this function returns. Only the name and size are kept.
    ///
    /// # Errors
    /// Propagates any [`RdmaError`] from opening the transport or registering memory.
    pub fn prepare(request: &DirectPullRequest) -> Result<Self, RdmaError> {
        let config = QueuePairConfig {
            completion_entries: request.completion_entries,
            send_wr: request.queue_depth,
            recv_wr: request.queue_depth,
            send_sge: 1,
            recv_sge: 1,
            inline_bytes: request.inline_bytes,
            sq_sig_all: true,
        };
        let transport = RdmaTransport::open(request.device_name.as_deref(), config)?;
        let mut scratch = vec![0_u8; request.registration_bytes];
        let _registration = transport.register_memory(&mut scratch)?;
        Ok(DirectExecutor {
            device_name: transport.device_name().to_string(),
            registered_bytes: request.registration_bytes,
        })
    }
}
/// Transfer task description, deserialized from camelCase JSON.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct TransferSpec {
    /// JSON key `taskID`; empty string when absent.
    #[serde(rename = "taskID", default)]
    task_id: String,
    artifact_key: String,
    /// Mode string; this file compares it with "PARTIAL_PULL_ALLGATHER" and "DIRECT_STRIPED".
    transfer_mode: String,
    logical_manifest: LogicalManifest,
    /// Collective settings; an empty `CollectiveSpec` when absent.
    #[serde(default)]
    collective_spec: CollectiveSpec,
    source_segments: Vec<SourceSegmentPlan>,
    target_temp_path: String,
    /// When true, `shared_queue_key` narrows the write-queue key per file or task.
    #[serde(default)]
    preserve_existing: bool,
    #[serde(default)]
    force_sync_close: bool,
    /// JSON key `enableChunkCRC32C`.
    #[serde(rename = "enableChunkCRC32C", default)]
    enable_chunk_crc32c: bool,
    chunk_size_bytes: i64,
    #[serde(default)]
    parallelism: i32,
    /// Non-positive values are mapped to 1 by `source_chunk_retry_limit`.
    #[serde(default)]
    retry_limit: i32,
    /// Defaults to 0 via `default_timeout_seconds`.
    /// NOTE(review): how callers interpret 0 is not visible here; confirm.
    #[serde(default = "default_timeout_seconds")]
    timeout_seconds: i32,
}
/// Serde default for `TransferSpec::timeout_seconds` (zero).
fn default_timeout_seconds() -> i32 {
    i32::default()
}
/// Logical description of the artifact being transferred.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct LogicalManifest {
    artifact_key: String,
    root_path: String,
    chunk_size_bytes: i64,
    digest: String,
    // NOTE(review): presumably a timestamp; the unit is not visible here.
    generated_at: i64,
    files: Vec<ArtifactFile>,
}

/// One file entry of a [`LogicalManifest`].
/// `kind`, `chunkable` and `required` default when absent.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct ArtifactFile {
    relative_path: String,
    size_bytes: i64,
    #[serde(default)]
    kind: String,
    #[serde(default)]
    chunkable: bool,
    #[serde(default)]
    required: bool,
}
/// Plan for one source: where to read from and which byte ranges to fetch.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct SourceSegmentPlan {
    /// JSON key `sourceID`; empty string when absent.
    #[serde(rename = "sourceID", default)]
    source_id: String,
    source_endpoint: SourceEndpoint,
    #[serde(default)]
    byte_ranges: Vec<ByteRange>,
    /// Defaults to 1 via `default_source_weight`.
    #[serde(default = "default_source_weight")]
    weight: i32,
}

/// Address of a source.
/// The relay RDMA hint is accepted under either `relayRDMA` or `relayRdma`.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct SourceEndpoint {
    #[serde(rename = "sourceID", default)]
    source_id: String,
    #[serde(default)]
    source_type: String,
    #[serde(default)]
    endpoint: String,
    path: String,
    #[serde(default)]
    node_name: Option<String>,
    #[serde(default)]
    #[serde(rename = "relayRDMA", alias = "relayRdma")]
    relay_rdma: Option<RelayRdmaHint>,
}

/// A byte range within one file.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct ByteRange {
    relative_path: String,
    // NOTE(review): whether `end` is exclusive is not visible here.
    // This file only treats a range as non-empty when `end > start`.
    start: i64,
    end: i64,
    /// Presumably the range's offset inside a relay export buffer; 0 when absent. Confirm.
    #[serde(default)]
    relay_offset: i64,
}
/// RDMA relay hint: which export session to read and how to connect to it.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
struct RelayRdmaHint {
    /// JSON key `sessionID`.
    #[serde(rename = "sessionID")]
    session_id: String,
    connection_info: RDMAConnectionInfo,
    /// True for hints produced from persistent relay lanes in this file.
    /// False for one-shot relay sessions.
    #[serde(default)]
    persistent: bool,
}

/// Collective (multi-node) settings of a transfer; every field is optional.
#[derive(Debug, Deserialize, Clone, Default)]
#[serde(rename_all = "camelCase")]
struct CollectiveSpec {
    /// JSON key `sessionID`; empty means the task id is used (see `resolve_relay_root_publisher`).
    #[serde(default)]
    #[serde(rename = "sessionID")]
    session_id: String,
    /// Collective mode; this file checks for "RING".
    #[serde(default)]
    mode: String,
    #[serde(default)]
    ring: Option<RingPeerPlan>,
    #[serde(default)]
    peers: Vec<CollectivePeerPlan>,
}

/// This node's position in a ring collective.
#[derive(Debug, Deserialize, Clone, Default)]
#[serde(rename_all = "camelCase")]
struct RingPeerPlan {
    #[serde(default)]
    self_node: String,
    /// Agent endpoint of this node; relay metadata is pushed here.
    #[serde(default)]
    self_endpoint: String,
    #[serde(default)]
    rank: i32,
    #[serde(default)]
    world_size: i32,
}

/// One participant of a collective and the byte ranges it owns.
#[derive(Debug, Deserialize, Clone, Default)]
#[serde(rename_all = "camelCase")]
struct CollectivePeerPlan {
    #[serde(default)]
    node_name: String,
    #[serde(default)]
    endpoint: String,
    #[serde(default)]
    rank: i32,
    #[serde(default)]
    staging_path: String,
    #[serde(default)]
    owned_ranges: Vec<ByteRange>,
}
/// One unit of work: a byte span of a file fetched from a specific source.
/// NOTE(review): how `offset` and `source_offset` differ is not visible in this chunk.
#[derive(Debug, Clone)]
struct TransferJob {
    source_id: String,
    source: SourceEndpoint,
    relative_path: String,
    offset: i64,
    length: i64,
    source_offset: i64,
}

/// A chunk of a manifest file identified by path, offset and length.
#[derive(Debug, Clone)]
struct ManifestChunk {
    relative_path: String,
    offset: i64,
    length: i64,
}
/// 64 MiB; also used as the minimum relay batch / persistent lane size in this file.
const RDMA_DIRECT_JOB_BYTES: i64 = 64 * 1024 * 1024;
/// 1 MiB pipeline chunk for RDMA reads (usage not visible in this chunk).
const RDMA_READ_PIPELINE_CHUNK_BYTES: usize = 1 * 1024 * 1024;
const RDMA_READ_PIPELINE_DEPTH: usize = 16;
/// Upper clamp for the writer queue depth overrides.
const RDMA_LOCAL_SLOT_COUNT: usize = 3;
const LARGE_SINGLE_SOURCE_GROUP_BYTES: i64 = 1536 * 1024 * 1024;
const HUGE_SINGLE_SOURCE_TOTAL_BYTES: i64 = 128 * 1024 * 1024 * 1024;
/// Default relay publish batch size for non-ring transfers (512 MiB).
const RELAY_PUBLISH_BATCH_BYTES: i64 = 512 * 1024 * 1024;
/// Default relay publish batch size for ring relay collectives (256 MiB).
const SYMMETRIC_RELAY_PUBLISH_BATCH_BYTES: i64 = 256 * 1024 * 1024;
/// Default number of persistent relay lanes for ring relay collectives.
const SYMMETRIC_PERSISTENT_RELAY_LANES: usize = 2;
/// Sleep between relay acknowledgement polls.
const RELAY_ACK_POLL_INTERVAL_MS: u64 = 20;
/// Groups per relay iteration for non-ring transfers.
const RELAY_PUBLISH_GROUP_TARGET: usize = 2;
const DIRECT_STRIPED_PROBE_CHUNK_BYTES: i64 = 4 * 1024 * 1024;
const DIRECT_STRIPED_PROBE_READ_ROUNDS: usize = 3;
const DIRECT_STRIPED_SINGLE_SOURCE_THRESHOLD_MBPS: i64 = 900;
/// Set once `verify_source_crc32c` has injected its one-shot CRC mismatch.
static INJECTED_SOURCE_CHUNK_CRC_MISMATCH: AtomicBool = AtomicBool::new(false);
/// Retry limit for source chunks: `spec.retry_limit` when positive, otherwise 1.
fn source_chunk_retry_limit(spec: &TransferSpec) -> usize {
    match spec.retry_limit {
        limit if limit > 0 => limit as usize,
        _ => 1,
    }
}
/// Whether `WD_INJECT_CHUNK_CRC_MISMATCH_ONCE` requests a one-shot injected CRC mismatch.
///
/// Accepts `1`, `true`, `yes` or `on`, case-insensitively and ignoring surrounding
/// whitespace. A missing or non-UTF-8 variable means disabled.
fn inject_chunk_crc_mismatch_enabled() -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    match std::env::var("WD_INJECT_CHUNK_CRC_MISMATCH_ONCE") {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            TRUTHY.iter().any(|&candidate| candidate == normalized)
        }
        Err(_) => false,
    }
}
/// Checks `payload` against the CRC32C the source reported for it.
///
/// An absent or empty `expected_crc32c` skips verification. When
/// `WD_INJECT_CHUNK_CRC_MISMATCH_ONCE` is enabled, the first verification whose
/// checksum can be corrupted gets the checksum's first character flipped. This
/// exercises the mismatch path. The process-wide one-shot flag is consumed only
/// when the corruption is actually applied.
///
/// # Errors
/// Returns a message with both checksums when they differ.
fn verify_source_crc32c(payload: &[u8], expected_crc32c: Option<&str>) -> Result<(), String> {
    let Some(expected) = expected_crc32c.filter(|value| !value.is_empty()) else {
        return Ok(());
    };
    let mut actual = encode_crc32c(crc32c::crc32c(payload));
    // Check there is something to corrupt *before* swapping the one-shot flag.
    // Otherwise the injection could be used up without ever taking effect.
    if !actual.is_empty()
        && inject_chunk_crc_mismatch_enabled()
        && !INJECTED_SOURCE_CHUNK_CRC_MISMATCH.swap(true, Ordering::SeqCst)
    {
        let mut chars = actual.chars();
        // Flip by `char`, not by byte, so a non-ASCII encoding cannot hit a char-boundary panic.
        let flipped = match chars.next() {
            Some('0') => '1',
            _ => '0',
        };
        actual = std::iter::once(flipped).chain(chars).collect();
    }
    if actual != expected {
        return Err(format!("source chunk crc mismatch: expect={} actual={}", expected, actual));
    }
    Ok(())
}
/// Serde default for `SourceSegmentPlan::weight`: each source counts once.
fn default_source_weight() -> i32 {
    const UNIT_WEIGHT: i32 = 1;
    UNIT_WEIGHT
}
fn rdma_single_writer_queue_depth() -> usize {
std::env::var("WD_RDMA_SINGLE_WRITER_QUEUE_DEPTH")
.ok()
.and_then(|raw| raw.trim().parse::<usize>().ok())
.map(|value| value.clamp(2, RDMA_LOCAL_SLOT_COUNT))
.unwrap_or(0)
}
fn rdma_shared_writer_queue_depth() -> usize {
std::env::var("WD_RDMA_SHARED_WRITER_QUEUE_DEPTH")
.ok()
.and_then(|raw| raw.trim().parse::<usize>().ok())
.map(|value| value.clamp(2, RDMA_LOCAL_SLOT_COUNT))
.unwrap_or(0)
}
/// Whether `WD_FORCE_TCP_FALLBACK` forces TCP instead of RDMA.
/// The value must be `true` in any case, or exactly `1`.
///
/// The variable is read on first call and cached for the life of the process.
fn force_tcp_fallback_enabled() -> bool {
    static FORCE_TCP: OnceLock<bool> = OnceLock::new();
    let cached = FORCE_TCP.get_or_init(|| match std::env::var("WD_FORCE_TCP_FALLBACK") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    });
    *cached
}
/// Identity used when pushing relay chunk metadata to this node's agent.
#[derive(Debug, Clone)]
struct RelayRootPublisher {
    task_id: String,
    session_id: String,
    source_node: String,
    /// Agent address, normalized with `normalize_agent_endpoint` before use.
    endpoint: String,
}

/// JSON body of `POST {endpoint}/v1/collectives/chunks/push`.
/// Field order here is the serialized field order.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct PushCollectiveChunkRequestOwned {
    #[serde(rename = "taskID")]
    task_id: String,
    #[serde(rename = "sessionID")]
    session_id: String,
    iteration: i32,
    expected_chunks: i32,
    chunk: TransferredChunk,
    /// Omitted when `None`; the relay publishers in this file always send `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<String>,
    /// Offset of the chunk inside the relay export buffer (0 when unknown).
    #[serde(default)]
    relay_offset: i64,
    /// Omitted when there is no relay export session.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "relayRDMA")]
    relay_rdma: Option<RelayRdmaHint>,
    transport_path: String,
    source_node: String,
}

/// JSON body of `POST {endpoint}/v1/collectives/chunks/list`, used to poll acknowledgements.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListCollectiveChunksRequestOwned {
    #[serde(rename = "taskID")]
    task_id: String,
    #[serde(rename = "sessionID")]
    session_id: String,
    iteration: i32,
}

/// Part of the chunk-list response read by this file; defaults to 0 when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ListCollectiveChunksResponseOwned {
    #[serde(default)]
    #[serde(rename = "acknowledgedIteration")]
    acknowledged_iteration: i32,
}

/// Collective step request body.
/// NOTE(review): not used in this chunk; its endpoint and semantics are not visible here.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CollectiveStepRequestOwned {
    #[serde(rename = "taskID")]
    task_id: String,
    #[serde(rename = "sessionID")]
    session_id: String,
    iteration: i32,
    #[serde(default)]
    barrier_only: bool,
    #[serde(default)]
    acknowledge_only: bool,
}

/// A reusable relay export session ("lane") owned by a `RelayBatchPublisher`.
#[derive(Debug)]
struct PersistentRelayLane {
    session_id: String,
    /// Bytes each slot of the lane's export session can hold.
    capacity: usize,
    /// Last iteration published through this lane; 0 means never used.
    last_iteration: i32,
}
fn normalize_agent_endpoint(raw: &str) -> Result<String, String> {
if raw.is_empty() {
return Err("agent endpoint is required".to_string());
}
if raw.contains("://") {
return Ok(raw.trim_end_matches('/').to_string());
}
if raw.contains(':') && !raw.starts_with('[') && raw.matches(':').count() == 1 {
return Ok(format!("http://{}", raw.trim_end_matches('/')));
}
Ok(format!(
"http://{}:{}",
raw.trim_end_matches('/'),
default_agent_port()
))
}
/// Port used when an agent endpoint omits one.
/// This is `AGENT_PORT` when it parses (after trimming) to a non-zero `u16`, otherwise 18080.
fn default_agent_port() -> u16 {
    const FALLBACK_AGENT_PORT: u16 = 18080;
    match std::env::var("AGENT_PORT").map(|raw| raw.trim().parse::<u16>()) {
        Ok(Ok(port)) if port != 0 => port,
        _ => FALLBACK_AGENT_PORT,
    }
}
/// Summary of a finished transfer, serialized as camelCase JSON.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct TransferResult {
    #[serde(rename = "taskID")]
    task_id: String,
    temp_path: String,
    bytes_transferred: i64,
    chunk_count: i32,
    succeeded_chunks: i32,
    failed_chunks: i32,
    /// JSON key `throughputMBps`.
    #[serde(rename = "throughputMBps")]
    throughput_m_bps: f64,
    /// Presumably one of the labels from `PayloadTransport::as_result_path`; confirm at the call site.
    transport_path: String,
    started_at: i64,
    finished_at: i64,
    transferred_chunks: Vec<TransferredChunk>,
    solidified_manifest: Option<SolidifiedManifest>,
}
/// Transport that carried a chunk's payload.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PayloadTransport {
    Rdma,
    TcpFallback,
}
impl PayloadTransport {
    /// Labels a transfer for its result.
    /// Returns "RDMA" when every chunk used RDMA, "MIXED" when only some did,
    /// and "TCP_FALLBACK" when none did.
    fn as_result_path(all_rdma: bool, any_rdma: bool) -> String {
        let label = match (all_rdma, any_rdma) {
            (true, _) => "RDMA",
            (false, true) => "MIXED",
            (false, false) => "TCP_FALLBACK",
        };
        label.to_string()
    }
}
/// Metadata of one transferred chunk.
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
struct TransferredChunk {
    /// JSON key `chunkID`; also the key of relay offset maps in this file.
    #[serde(rename = "chunkID")]
    chunk_id: String,
    file_path: String,
    relative_path: String,
    offset: i64,
    /// Payload length in bytes.
    size: i64,
    #[serde(rename = "crc32c")]
    crc32c: String,
    #[serde(rename = "sourceID")]
    source_id: String,
}

/// Manifest of a completed transfer: the logical digest plus per-file and per-chunk records.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SolidifiedManifest {
    artifact_key: String,
    logical_digest: String,
    chunk_size_bytes: i64,
    generated_at: i64,
    files: Vec<FileDigest>,
    chunks: Vec<TransferredChunk>,
}

/// Per-file entry of a [`SolidifiedManifest`].
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct FileDigest {
    relative_path: String,
    size_bytes: i64,
    chunk_count: i32,
}
/// Request naming a file under an export root, presumably for a size lookup.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct StatExportRequest<'a> {
    root_path: &'a str,
    relative_path: &'a str,
}

/// Response carrying a file size in bytes.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StatExportResponse {
    size_bytes: i64,
}

/// Request for `length` bytes at `offset` of a file under an export root.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ReadChunkRequest<'a> {
    root_path: &'a str,
    relative_path: &'a str,
    offset: i64,
    length: i64,
}

/// Chunk read response.
/// NOTE(review): `data` is presumably base64 (the file imports a base64 engine); confirm.
/// `crc32c` is empty when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReadChunkResponse {
    data: String,
    #[serde(default)]
    crc32c: String,
}

/// Wire form of an RDMA endpoint, built from local endpoints via `rdma_endpoint_to_wire`.
/// Field names follow the usual InfiniBand/RoCE terms (qpn, psn, lid, gid, rkey)
/// plus the remote buffer address and length.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct RDMAConnectionInfo {
    #[serde(default)]
    device_name: String,
    qpn: u32,
    psn: u32,
    #[serde(default)]
    lid: u16,
    gid: String,
    rkey: u32,
    addr: u64,
    length: u64,
}

/// Request to open an RDMA export for a byte range of a file.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct OpenRDMAExportRequest<'a> {
    root_path: &'a str,
    relative_path: &'a str,
    offset: i64,
    length: i64,
}

/// Response describing an opened RDMA export session.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct OpenRDMAExportResponse {
    #[serde(rename = "sessionID")]
    session_id: String,
    transport_path: String,
    connection_info: RDMAConnectionInfo,
    #[serde(default)]
    expected_crc32c: String,
}

/// Request to connect an export session to the caller's RDMA endpoint.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ConnectRDMAExportRequest {
    #[serde(rename = "sessionID")]
    session_id: String,
    connection_info: RDMAConnectionInfo,
    timeout_seconds: i32,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConnectRDMAExportResponse {
    #[serde(rename = "sessionID")]
    session_id: String,
    transport_path: String,
}

/// Request to move an existing export session to a new byte range.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct UpdateRDMAExportRequest {
    #[serde(rename = "sessionID")]
    session_id: String,
    offset: i64,
    length: i64,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UpdateRDMAExportResponse {
    #[serde(rename = "sessionID")]
    session_id: String,
    transport_path: String,
    connection_info: RDMAConnectionInfo,
    #[serde(default)]
    expected_crc32c: String,
}

/// Request to close an export session.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CloseRDMAExportRequest {
    #[serde(rename = "sessionID")]
    session_id: String,
}
/// One registered buffer of an export session together with its advertised endpoint.
struct ExportSlot {
    // Declaration order is drop order.
    // `registered` is dropped before `buffer`, presumably so the memory region is
    // deregistered before the buffer is freed. Keep this order.
    registered: RegisteredMemory,
    buffer: Vec<u8>,
    endpoint: RdmaEndpointInfo,
}

/// State of one RDMA export session.
/// Relay sessions use no source file and two slots that are swapped on update.
struct ExportSession {
    source_path: Option<PathBuf>,
    source_file: Option<File>,
    // `slots` is declared before `transport`, so the registrations drop before the transport does.
    slots: [ExportSlot; 2],
    transport: RdmaTransport,
    /// Index (0 or 1) of the slot most recently made visible.
    active_slot: usize,
    /// Only mutable sessions accept `update_relay_export_session`.
    mutable: bool,
}

/// Registry of export sessions keyed by session id.
static EXPORT_SESSIONS: OnceLock<Mutex<HashMap<String, Arc<Mutex<ExportSession>>>>> = OnceLock::new();
// NOTE(review): not referenced in this chunk.
static SHARED_WRITE_QUEUES: OnceLock<Mutex<HashMap<String, Arc<SharedWriteQueue>>>> = OnceLock::new();
/// Monotonic suffix that keeps generated relay session ids unique within the process.
static SESSION_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Packs the payloads of `outcomes` back to back into one buffer and exposes it
/// through a new, immutable RDMA relay export session.
///
/// Returns `Ok(None)` in three cases: forced TCP fallback is enabled, `outcomes`
/// is empty, or the chunks contain no bytes. Otherwise it returns:
/// * the relay hint for the new session (`persistent: false`);
/// * a map from chunk id to that chunk's offset in the packed buffer. Zero-size chunks get no entry.
///
/// The session is inserted into the export-session registry under a fresh
/// `rdma-relay-<millis>-<counter>` id. This function never removes it.
///
/// # Errors
/// Fails if a payload length disagrees with its chunk size. Also fails if opening
/// the transport, registering memory, querying endpoints, or locking the registry fails.
fn create_relay_export_session(outcomes: &[RelayPublishedChunk]) -> Result<Option<(RelayRdmaHint, HashMap<String, i64>)>, String> {
    if force_tcp_fallback_enabled() {
        return Ok(None);
    }
    if outcomes.is_empty() {
        return Ok(None);
    }
    let total_started = Instant::now();
    // Negative chunk sizes contribute nothing to the total.
    let total_bytes = outcomes
        .iter()
        .map(|outcome| outcome.chunk.size.max(0))
        .sum::<i64>();
    if total_bytes <= 0 {
        return Ok(None);
    }
    let copy_started = Instant::now();
    // Pack payloads contiguously; each chunk's starting offset is recorded in `relay_offsets`.
    let mut buffer = vec![0_u8; total_bytes as usize];
    let mut relay_offsets = HashMap::<String, i64>::with_capacity(outcomes.len());
    let mut cursor = 0_usize;
    for outcome in outcomes {
        // NOTE(review): a negative size wraps to a huge usize here.
        // The length check below then rejects it.
        let size = outcome.chunk.size as usize;
        if size == 0 {
            continue;
        }
        if outcome.payload.len() != size {
            return Err(format!(
                "relay payload size mismatch for {}@{}: expected {}, got {}",
                outcome.chunk.relative_path,
                outcome.chunk.offset,
                size,
                outcome.payload.len()
            ));
        }
        buffer[cursor..cursor + size].copy_from_slice(&outcome.payload);
        relay_offsets.insert(outcome.chunk.chunk_id.clone(), cursor as i64);
        cursor += size;
    }
    let copy_ms = copy_started.elapsed().as_millis();
    let transport_started = Instant::now();
    let transport = RdmaTransport::open_default().map_err(|err| format!("open relay rdma transport: {err}"))?;
    let transport_ms = transport_started.elapsed().as_millis();
    let mut primary = buffer;
    // The session is immutable, so the standby slot is only a 1-byte placeholder registration.
    let mut standby = vec![0_u8; 1];
    let register_started = Instant::now();
    let primary_registered = transport
        .register_memory(&mut primary)
        .map_err(|err| format!("register relay rdma memory: {err}"))?;
    let standby_registered = transport
        .register_memory(&mut standby)
        .map_err(|err| format!("register relay standby memory: {err}"))?;
    let register_ms = register_started.elapsed().as_millis();
    let endpoint_started = Instant::now();
    let primary_endpoint = transport
        .local_endpoint(&primary_registered)
        .map_err(|err| format!("query relay rdma endpoint: {err}"))?;
    let standby_endpoint = transport
        .local_endpoint(&standby_registered)
        .map_err(|err| format!("query relay standby rdma endpoint: {err}"))?;
    let endpoint_ms = endpoint_started.elapsed().as_millis();
    let session_id = format!("rdma-relay-{}-{}", now_millis(), SESSION_COUNTER.fetch_add(1, Ordering::Relaxed));
    // Moving the Vecs into the slots moves only their handles, not the registered heap memory.
    let session = ExportSession {
        source_path: None,
        source_file: None,
        transport,
        slots: [
            ExportSlot {
                buffer: primary,
                registered: primary_registered,
                endpoint: primary_endpoint.clone(),
            },
            ExportSlot {
                buffer: standby,
                registered: standby_registered,
                endpoint: standby_endpoint,
            },
        ],
        active_slot: 0,
        mutable: false,
    };
    let insert_started = Instant::now();
    export_sessions()
        .lock()
        .map_err(|_| "lock export sessions".to_string())?
        .insert(session_id.clone(), Arc::new(Mutex::new(session)));
    let insert_ms = insert_started.elapsed().as_millis();
    eprintln!(
        "relay export session created session={} chunks={} bytes={} copy_ms={} transport_ms={} register_ms={} endpoint_ms={} insert_ms={} total_ms={}",
        session_id,
        outcomes.len(),
        total_bytes,
        copy_ms,
        transport_ms,
        register_ms,
        endpoint_ms,
        insert_ms,
        total_started.elapsed().as_millis()
    );
    Ok(Some((
        RelayRdmaHint {
            session_id,
            connection_info: rdma_endpoint_to_wire("", &primary_endpoint),
            persistent: false,
        },
        relay_offsets,
    )))
}
/// Creates a mutable, initially empty relay export session (a persistent lane)
/// with two `capacity`-byte slots, and registers it in the export-session registry.
///
/// `active_slot` starts at 1, so the first `update_relay_export_session` call
/// writes slot 0. Slot 0 is the primary, whose endpoint the returned hint advertises.
///
/// # Errors
/// Fails when forced TCP fallback is enabled or `capacity` is 0. Also fails if the
/// transport, memory registration, endpoint query, or registry lock fails.
fn create_empty_relay_export_session(capacity: usize) -> Result<RelayRdmaHint, String> {
    if force_tcp_fallback_enabled() {
        return Err("relay lane disabled under forced tcp fallback".to_string());
    }
    if capacity == 0 {
        return Err("relay lane capacity must be positive".to_string());
    }
    // Both slots are full-size so updates can alternate between them.
    let mut primary = vec![0_u8; capacity];
    let mut standby = vec![0_u8; capacity];
    let transport = RdmaTransport::open_default().map_err(|err| format!("open relay lane rdma transport: {err}"))?;
    let primary_registered = transport
        .register_memory(&mut primary)
        .map_err(|err| format!("register relay lane primary memory: {err}"))?;
    let standby_registered = transport
        .register_memory(&mut standby)
        .map_err(|err| format!("register relay lane standby memory: {err}"))?;
    let primary_endpoint = transport
        .local_endpoint(&primary_registered)
        .map_err(|err| format!("query relay lane primary endpoint: {err}"))?;
    let standby_endpoint = transport
        .local_endpoint(&standby_registered)
        .map_err(|err| format!("query relay lane standby endpoint: {err}"))?;
    let session_id = format!("rdma-relay-{}-{}", now_millis(), SESSION_COUNTER.fetch_add(1, Ordering::Relaxed));
    let session = ExportSession {
        source_path: None,
        source_file: None,
        transport,
        slots: [
            ExportSlot {
                buffer: primary,
                registered: primary_registered,
                endpoint: primary_endpoint.clone(),
            },
            ExportSlot {
                buffer: standby,
                registered: standby_registered,
                endpoint: standby_endpoint,
            },
        ],
        // Starting on slot 1 makes the first update land in slot 0.
        active_slot: 1,
        mutable: true,
    };
    export_sessions()
        .lock()
        .map_err(|_| "lock export sessions".to_string())?
        .insert(session_id.clone(), Arc::new(Mutex::new(session)));
    eprintln!(
        "relay lane created session={} capacity={} bytes",
        session_id,
        capacity
    );
    Ok(RelayRdmaHint {
        session_id,
        connection_info: rdma_endpoint_to_wire("", &primary_endpoint),
        persistent: true,
    })
}
/// Packs `outcomes` into the inactive slot of the mutable relay session
/// `session_id`, then makes that slot active.
///
/// Only the inactive slot is written, so the slot advertised by the previous
/// hint is not overwritten by this call. The returned hint advertises the newly
/// filled slot, with its `length` set to the bytes written. The offsets map is
/// keyed by chunk id; zero-size chunks get no entry.
///
/// # Errors
/// Fails when there are no bytes to publish, or the session is missing or immutable.
/// Also fails when the bytes exceed the slot capacity, a payload length disagrees
/// with its chunk size, or a lock is poisoned. On error `active_slot` is left unchanged.
fn update_relay_export_session(
    session_id: &str,
    outcomes: &[RelayPublishedChunk],
) -> Result<(RelayRdmaHint, HashMap<String, i64>), String> {
    let total_bytes = outcomes
        .iter()
        .map(|outcome| outcome.chunk.size.max(0))
        .sum::<i64>();
    if total_bytes <= 0 {
        return Err(format!("relay session {} has no bytes to publish", session_id));
    }
    // Clone the Arc and release the registry lock before locking the session itself.
    let session = {
        let sessions = export_sessions().lock().map_err(|_| "lock export sessions".to_string())?;
        sessions
            .get(session_id)
            .cloned()
            .ok_or_else(|| format!("rdma relay export session {} not found", session_id))?
    };
    let mut session = session
        .lock()
        .map_err(|_| format!("lock relay export session {}", session_id))?;
    if !session.mutable {
        return Err(format!("rdma relay export session {} is not mutable", session_id));
    }
    let next_slot = 1 - session.active_slot;
    let mut relay_offsets = HashMap::<String, i64>::with_capacity(outcomes.len());
    let mut cursor = 0_usize;
    let endpoint = {
        let slot = &mut session.slots[next_slot];
        if total_bytes as usize > slot.buffer.len() {
            return Err(format!(
                "relay session {} capacity {} smaller than {} bytes",
                session_id,
                slot.buffer.len(),
                total_bytes
            ));
        }
        for outcome in outcomes {
            // NOTE(review): a negative size wraps here and is rejected by the length check below.
            let size = outcome.chunk.size as usize;
            if size == 0 {
                continue;
            }
            if outcome.payload.len() != size {
                return Err(format!(
                    "relay payload size mismatch for {}@{}: expected {}, got {}",
                    outcome.chunk.relative_path,
                    outcome.chunk.offset,
                    size,
                    outcome.payload.len()
                ));
            }
            slot.buffer[cursor..cursor + size].copy_from_slice(&outcome.payload);
            relay_offsets.insert(outcome.chunk.chunk_id.clone(), cursor as i64);
            cursor += size;
        }
        // Advertise only the bytes written in this update, not the full slot capacity.
        slot.endpoint.length = total_bytes as u64;
        slot.endpoint.clone()
    };
    session.active_slot = next_slot;
    Ok((
        RelayRdmaHint {
            session_id: session_id.to_string(),
            connection_info: rdma_endpoint_to_wire("", &endpoint),
            persistent: true,
        },
        relay_offsets,
    ))
}
/// Builds the relay publisher identity for a ring all-gather participant that owns data.
///
/// Returns `None` unless all of the following hold:
/// * the mode is `PARTIAL_PULL_ALLGATHER` with a `RING` collective;
/// * the world size is at least 2 and the self endpoint is non-empty;
/// * the first peer whose rank equals this node's ring rank owns a non-empty range.
///
/// The session id falls back to the task id. So does the source node name when `self_node` is empty.
fn resolve_relay_root_publisher(spec: &TransferSpec) -> Option<RelayRootPublisher> {
    let collective = &spec.collective_spec;
    let is_ring_allgather =
        spec.transfer_mode == "PARTIAL_PULL_ALLGATHER" && collective.mode == "RING";
    if !is_ring_allgather {
        return None;
    }
    let ring = collective.ring.as_ref()?;
    if ring.world_size < 2 || ring.self_endpoint.is_empty() {
        return None;
    }
    let owns_data = collective
        .peers
        .iter()
        .find(|peer| peer.rank == ring.rank)
        .map_or(false, |peer| peer.owned_ranges.iter().any(|range| range.end > range.start));
    if !owns_data {
        return None;
    }
    let non_empty_or_task = |value: &str| -> String {
        if value.is_empty() {
            spec.task_id.clone()
        } else {
            value.to_string()
        }
    };
    Some(RelayRootPublisher {
        task_id: spec.task_id.clone(),
        session_id: non_empty_or_task(&collective.session_id),
        source_node: non_empty_or_task(&ring.self_node),
        endpoint: ring.self_endpoint.clone(),
    })
}
/// True for a `PARTIAL_PULL_ALLGATHER` over a `RING` collective that meets two conditions:
/// the world size is at least 2, and at least two peers own a non-empty byte range.
fn is_ring_relay_collective(spec: &TransferSpec) -> bool {
    let collective = &spec.collective_spec;
    if spec.transfer_mode != "PARTIAL_PULL_ALLGATHER" || collective.mode != "RING" {
        return false;
    }
    let multi_rank = matches!(collective.ring.as_ref(), Some(ring) if ring.world_size >= 2);
    if !multi_rank {
        return false;
    }
    let owning_peers = collective
        .peers
        .iter()
        .filter(|peer| peer.owned_ranges.iter().any(|range| range.end > range.start))
        .take(2)
        .count();
    owning_peers >= 2
}
/// Publishes a relay group's chunk metadata, staging the payloads in a fresh
/// one-shot RDMA export session first.
///
/// The export is `None` when forced TCP fallback is on or the group has no bytes;
/// the metadata is pushed without a relay hint in that case.
///
/// # Errors
/// Fails if the export session cannot be created or any chunk push fails.
fn publish_relay_group_metadata(
    client: &Client,
    publisher: &RelayRootPublisher,
    iteration: i32,
    expected_chunks: i32,
    outcomes: &[RelayPublishedChunk],
) -> Result<(), String> {
    let ephemeral_export = create_relay_export_session(outcomes)?;
    publish_relay_group_metadata_with_hint(
        client,
        publisher,
        iteration,
        expected_chunks,
        outcomes,
        ephemeral_export,
    )
}
/// Sends one `POST /v1/collectives/chunks/push` request per chunk in `outcomes`
/// to the publisher's agent.
///
/// Every request carries:
/// * the same iteration and expected chunk count;
/// * the same optional relay RDMA hint;
/// * a `relayOffset`, which is the chunk's offset in the relay export
///   (0 without an export or when the chunk has no recorded offset);
/// * a transport path of "RDMA" only if every chunk arrived over RDMA,
///   otherwise "TCP_FALLBACK".
///
/// # Errors
/// Fails on an invalid endpoint, a send error, or a non-success HTTP status.
fn publish_relay_group_metadata_with_hint(
    client: &Client,
    publisher: &RelayRootPublisher,
    iteration: i32,
    expected_chunks: i32,
    outcomes: &[RelayPublishedChunk],
    relay_export: Option<(RelayRdmaHint, HashMap<String, i64>)>,
) -> Result<(), String> {
    let endpoint = normalize_agent_endpoint(&publisher.endpoint)?;
    let push_url = format!("{}/v1/collectives/chunks/push", endpoint);
    let every_chunk_rdma = outcomes
        .iter()
        .all(|outcome| outcome.transport == PayloadTransport::Rdma);
    let transport_path = if every_chunk_rdma { "RDMA" } else { "TCP_FALLBACK" };
    let (relay_hint, relay_offsets) = match relay_export {
        Some((hint, offsets)) => (Some(hint), Some(offsets)),
        None => (None, None),
    };
    for outcome in outcomes {
        let chunk = &outcome.chunk;
        let relay_offset = relay_offsets
            .as_ref()
            .and_then(|offsets| offsets.get(&chunk.chunk_id))
            .copied()
            .unwrap_or_default();
        let request = PushCollectiveChunkRequestOwned {
            task_id: publisher.task_id.clone(),
            session_id: publisher.session_id.clone(),
            iteration,
            expected_chunks,
            chunk: chunk.clone(),
            data: None,
            relay_offset,
            relay_rdma: relay_hint.clone(),
            transport_path: transport_path.to_string(),
            source_node: publisher.source_node.clone(),
        };
        let response = client
            .post(push_url.as_str())
            .json(&request)
            .send()
            .map_err(|err| format!("publish relay collective chunk {}@{}: {err}", chunk.relative_path, chunk.offset))?;
        let status = response.status();
        if !status.is_success() {
            return Err(format!(
                "publish relay collective chunk {}@{} failed with status {}",
                chunk.relative_path,
                chunk.offset,
                status
            ));
        }
    }
    Ok(())
}
/// Polls the publisher's agent until it reports `acknowledgedIteration >= iteration`.
///
/// Non-positive iterations count as already acknowledged. The agent is queried
/// at least once. With a zero `timeout` this is therefore a single non-blocking probe.
/// Between polls the loop sleeps `RELAY_ACK_POLL_INTERVAL_MS`, but never past the deadline.
/// Non-success HTTP statuses are retried until the deadline.
///
/// # Errors
/// Fails on an invalid endpoint, a send or decode error, or when the deadline
/// passes without acknowledgement.
fn wait_for_collective_ack(
    client: &Client,
    publisher: &RelayRootPublisher,
    iteration: i32,
    timeout: Duration,
) -> Result<(), String> {
    if iteration <= 0 {
        return Ok(());
    }
    let endpoint = normalize_agent_endpoint(&publisher.endpoint)?;
    let list_url = format!("{}/v1/collectives/chunks/list", endpoint);
    // `Instant + Duration` panics on overflow.
    // An unrepresentable deadline is treated as "no deadline".
    let deadline = Instant::now().checked_add(timeout);
    let poll_interval = Duration::from_millis(RELAY_ACK_POLL_INTERVAL_MS);
    loop {
        let response = client
            .post(list_url.as_str())
            .json(&ListCollectiveChunksRequestOwned {
                task_id: publisher.task_id.clone(),
                session_id: publisher.session_id.clone(),
                iteration,
            })
            .send()
            .map_err(|err| format!("poll relay ack iteration {}: {err}", iteration))?;
        if response.status().is_success() {
            let body: ListCollectiveChunksResponseOwned = response
                .json()
                .map_err(|err| format!("decode relay ack iteration {}: {err}", iteration))?;
            if body.acknowledged_iteration >= iteration {
                return Ok(());
            }
        }
        let pause = match deadline {
            Some(deadline) => {
                let now = Instant::now();
                if now >= deadline {
                    return Err(format!("timed out waiting for relay ack iteration {}", iteration));
                }
                // Never sleep past the deadline.
                poll_interval.min(deadline - now)
            }
            None => poll_interval,
        };
        std::thread::sleep(pause);
    }
}
/// Batches relay chunk groups into numbered iterations and publishes each
/// iteration on a background thread. Clones share the same state through `Arc`s.
#[derive(Clone)]
struct RelayBatchPublisher {
    client: Client,
    publisher: RelayRootPublisher,
    /// Groups collected before an iteration is dispatched.
    group_target: usize,
    /// Upper bound on concurrently running publish threads.
    max_inflight: usize,
    batch_bytes: i64,
    ack_timeout: Duration,
    /// Persistent relay lanes, when at least one could be opened.
    lanes: Option<Arc<Vec<Mutex<PersistentRelayLane>>>>,
    state: Arc<Mutex<RelayPublishState>>,
    /// Held across reaping and pushing a new worker, so the in-flight limit is enforced under one lock.
    dispatch_lock: Arc<Mutex<()>>,
}

/// Mutable bookkeeping shared by all clones of a [`RelayBatchPublisher`].
#[derive(Default)]
struct RelayPublishState {
    /// Iteration number for the next dispatched batch (set to 1 by `RelayBatchPublisher::new`).
    next_iteration: i32,
    pending_groups: Vec<Vec<RelayPublishedChunk>>,
    /// Sum of chunk sizes in `pending_groups`.
    pending_bytes: i64,
    /// Join handles of publish threads that have not been joined yet.
    inflight: Vec<JoinHandle<Result<(), String>>>,
}
impl RelayBatchPublisher {
fn new(
client: Client,
publisher: RelayRootPublisher,
_target_temp_path: String,
group_target: usize,
max_inflight: usize,
batch_bytes: i64,
ack_timeout: Duration,
persistent_lanes: usize,
) -> Self {
let normalized_batch_bytes = batch_bytes.max(RDMA_DIRECT_JOB_BYTES);
let lanes = if persistent_lanes > 0 {
let mut created = Vec::with_capacity(persistent_lanes);
for _ in 0..persistent_lanes {
if let Ok(hint) = create_empty_relay_export_session(normalized_batch_bytes as usize) {
created.push(Mutex::new(PersistentRelayLane {
session_id: hint.session_id,
capacity: normalized_batch_bytes as usize,
last_iteration: 0,
}));
}
}
if created.is_empty() {
None
} else {
Some(Arc::new(created))
}
} else {
None
};
Self {
client,
publisher,
group_target: group_target.max(1),
max_inflight: max_inflight.max(1),
batch_bytes: normalized_batch_bytes,
ack_timeout,
lanes,
state: Arc::new(Mutex::new(RelayPublishState {
next_iteration: 1,
pending_groups: Vec::new(),
pending_bytes: 0,
inflight: Vec::new(),
})),
dispatch_lock: Arc::new(Mutex::new(())),
}
}
fn batch_bytes(&self) -> i64 {
self.batch_bytes
}
fn publish(&self, outcomes: Vec<RelayPublishedChunk>) -> Result<(), String> {
if outcomes.is_empty() {
return Ok(());
}
if let Some((iteration, groups, bytes)) = self.enqueue(outcomes, false)? {
self.dispatch_iteration(iteration, groups, bytes)?;
}
Ok(())
}
fn flush(&self) -> Result<(), String> {
if let Some((iteration, groups, bytes)) = self.enqueue(Vec::new(), true)? {
self.dispatch_iteration(iteration, groups, bytes)?;
}
self.wait_for_all()
}
fn enqueue(
&self,
outcomes: Vec<RelayPublishedChunk>,
flush: bool,
) -> Result<Option<(i32, Vec<Vec<RelayPublishedChunk>>, i64)>, String> {
let mut state = self
.state
.lock()
.map_err(|_| "lock relay batch publisher state".to_string())?;
if !outcomes.is_empty() {
state.pending_bytes += outcomes.iter().map(|outcome| outcome.chunk.size).sum::<i64>();
state.pending_groups.push(outcomes);
}
let ready = if flush {
!state.pending_groups.is_empty()
} else {
state.pending_groups.len() >= self.group_target
};
if !ready {
return Ok(None);
}
let iteration = state.next_iteration;
state.next_iteration += 1;
let bytes = state.pending_bytes;
state.pending_bytes = 0;
let groups = std::mem::take(&mut state.pending_groups);
Ok(Some((iteration, groups, bytes)))
}
fn dispatch_iteration(
&self,
iteration: i32,
groups: Vec<Vec<RelayPublishedChunk>>,
bytes: i64,
) -> Result<(), String> {
let _dispatch_guard = self
.dispatch_lock
.lock()
.map_err(|_| "lock relay batch publisher dispatch".to_string())?;
self.reap_inflight(true)?;
let client = self.client.clone();
let publisher = self.publisher.clone();
let ack_timeout = self.ack_timeout;
let lanes = self.lanes.clone();
let handle = std::thread::spawn(move || {
publish_relay_iteration(&client, &publisher, iteration, groups, bytes, ack_timeout, lanes)
});
let mut state = self
.state
.lock()
.map_err(|_| "lock relay batch publisher state".to_string())?;
state.inflight.push(handle);
drop(state);
self.reap_inflight(false)
}
fn reap_inflight(&self, force_oldest: bool) -> Result<(), String> {
let mut handles = Vec::new();
{
let mut state = self
.state
.lock()
.map_err(|_| "lock relay batch publisher state".to_string())?;
let mut index = 0;
while index < state.inflight.len() {
if state.inflight[index].is_finished() {
handles.push(state.inflight.swap_remove(index));
} else {
index += 1;
}
}
if force_oldest && handles.is_empty() && state.inflight.len() >= self.max_inflight {
handles.push(state.inflight.remove(0));
}
}
for handle in handles {
match handle.join() {
Ok(Ok(())) => {}
Ok(Err(err)) => return Err(err),
Err(_) => return Err("relay publish worker panicked".to_string()),
}
}
Ok(())
}
fn wait_for_all(&self) -> Result<(), String> {
let handles = {
let mut state = self
.state
.lock()
.map_err(|_| "lock relay batch publisher state".to_string())?;
std::mem::take(&mut state.inflight)
};
for handle in handles {
match handle.join() {
Ok(Ok(())) => {}
Ok(Err(err)) => return Err(err),
Err(_) => return Err("relay publish worker panicked".to_string()),
}
}
Ok(())
}
}
/// Publishes every group of one relay iteration, logging start and completion.
///
/// With persistent lanes, each group's lane is chosen from the iteration number
/// plus the group index. This way consecutive iterations rotate across lanes even
/// when each iteration holds a single group. A lane is reused only if the agent
/// has already acknowledged the lane's previous iteration (probed without waiting).
/// A busy lane, or one whose export session cannot take the group, falls back to a
/// one-shot export session. Without lanes, every group uses a one-shot session.
///
/// `_ack_timeout` is currently unused, because lane availability is a non-blocking probe.
///
/// # Errors
/// Fails if a lane lock is poisoned, a one-shot session cannot be created, or any
/// metadata push fails.
fn publish_relay_iteration(
    client: &Client,
    publisher: &RelayRootPublisher,
    iteration: i32,
    groups: Vec<Vec<RelayPublishedChunk>>,
    bytes: i64,
    _ack_timeout: Duration,
    lanes: Option<Arc<Vec<Mutex<PersistentRelayLane>>>>,
) -> Result<(), String> {
    let started = Instant::now();
    let chunk_count = groups.iter().map(|group| group.len()).sum::<usize>();
    let expected_chunks = chunk_count as i32;
    eprintln!(
        "relay batch publish iteration={} groups={} chunks={} bytes={} first_chunk={}",
        iteration,
        groups.len(),
        chunk_count,
        bytes,
        groups
            .first()
            .and_then(|group| group.first())
            .map(|outcome| format!("{}@{}", outcome.chunk.relative_path, outcome.chunk.offset))
            .unwrap_or_else(|| "none".to_string())
    );
    // An empty lane list would make the modulo below divide by zero.
    let lanes = lanes.filter(|lanes| !lanes.is_empty());
    for (index, group) in groups.into_iter().enumerate() {
        let published_on_lane = match lanes.as_ref() {
            Some(lanes) => {
                // Rotate by iteration as well as by group index. With one group per
                // iteration, `index` alone would always pick lane 0 and leave the other lanes idle.
                let lane_index = (iteration.max(0) as usize).wrapping_add(index) % lanes.len();
                publish_relay_group_on_lane(
                    client,
                    publisher,
                    iteration,
                    expected_chunks,
                    &group,
                    &lanes[lane_index],
                )?
            }
            None => false,
        };
        if !published_on_lane {
            publish_relay_group_metadata(client, publisher, iteration, expected_chunks, &group)?;
        }
    }
    eprintln!(
        "relay batch publish completed iteration={} chunks={} bytes={} total_ms={}",
        iteration,
        chunk_count,
        bytes,
        started.elapsed().as_millis()
    );
    Ok(())
}

/// Tries to publish `group` through one persistent relay lane.
///
/// Returns `Ok(true)` when the group was published on the lane. Returns `Ok(false)`
/// when the lane's previous iteration is not yet acknowledged, or when its export
/// session could not be updated; the caller then falls back to a one-shot session.
fn publish_relay_group_on_lane(
    client: &Client,
    publisher: &RelayRootPublisher,
    iteration: i32,
    expected_chunks: i32,
    group: &[RelayPublishedChunk],
    lane: &Mutex<PersistentRelayLane>,
) -> Result<bool, String> {
    let (session_id, previous_iteration) = {
        let guard = lane
            .lock()
            .map_err(|_| "lock persistent relay lane".to_string())?;
        (guard.session_id.clone(), guard.last_iteration)
    };
    let lane_available = previous_iteration <= 0
        || wait_for_collective_ack(client, publisher, previous_iteration, Duration::ZERO).is_ok();
    if !lane_available {
        eprintln!(
            "persistent relay lane busy session={} previous_iteration={} publish_iteration={} fallback=ephemeral",
            session_id,
            previous_iteration,
            iteration
        );
        return Ok(false);
    }
    let relay_export = match update_relay_export_session(&session_id, group) {
        Ok(export) => export,
        Err(err) => {
            // For example, the group is larger than the lane capacity: fall back instead of failing the iteration.
            eprintln!(
                "persistent relay lane update failed session={} publish_iteration={} err={} fallback=ephemeral",
                session_id, iteration, err
            );
            return Ok(false);
        }
    };
    publish_relay_group_metadata_with_hint(
        client,
        publisher,
        iteration,
        expected_chunks,
        group,
        Some(relay_export),
    )?;
    lane.lock()
        .map_err(|_| "lock persistent relay lane".to_string())?
        .last_iteration = iteration;
    Ok(true)
}
/// Number of chunks grouped into one relay publish call.
///
/// Ring-relay collectives default to one chunk per group. A positive
/// `WD_RDMA_RELAY_PUBLISH_GROUP_TARGET` overrides that default. All other
/// transfers use `RELAY_PUBLISH_GROUP_TARGET` and ignore the env var.
fn relay_publish_group_target(spec: &TransferSpec) -> usize {
    if !is_ring_relay_collective(spec) {
        return RELAY_PUBLISH_GROUP_TARGET;
    }
    std::env::var("WD_RDMA_RELAY_PUBLISH_GROUP_TARGET")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .filter(|&target| target > 0)
        .unwrap_or(1)
}
/// Byte budget of one relay publish batch.
///
/// For ring-relay collectives, a positive
/// `WD_RDMA_SYMMETRIC_RELAY_PUBLISH_BATCH_MB` (in MiB) overrides
/// `SYMMETRIC_RELAY_PUBLISH_BATCH_BYTES`. The override is never smaller than
/// `RDMA_DIRECT_JOB_BYTES`. Other transfers always use
/// `RELAY_PUBLISH_BATCH_BYTES`.
fn relay_publish_batch_bytes(spec: &TransferSpec) -> i64 {
    if !is_ring_relay_collective(spec) {
        return RELAY_PUBLISH_BATCH_BYTES;
    }
    std::env::var("WD_RDMA_SYMMETRIC_RELAY_PUBLISH_BATCH_MB")
        .ok()
        .and_then(|raw| raw.parse::<i64>().ok())
        .filter(|&mb| mb > 0)
        // The value comes from the environment: saturate instead of
        // overflowing (panic in debug, wraparound in release) on huge inputs.
        .map(|mb| mb.saturating_mul(1024 * 1024).max(RDMA_DIRECT_JOB_BYTES))
        .unwrap_or(SYMMETRIC_RELAY_PUBLISH_BATCH_BYTES)
}
/// Maximum number of relay publish batches allowed in flight.
///
/// A positive `WD_RDMA_RELAY_PUBLISH_INFLIGHT` wins for every transfer kind.
/// Without it, ring-relay collectives get 2 and everything else gets 1.
fn relay_publish_max_inflight(spec: &TransferSpec) -> usize {
    let from_env = std::env::var("WD_RDMA_RELAY_PUBLISH_INFLIGHT")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .filter(|&limit| limit > 0);
    match from_env {
        Some(limit) => limit,
        None if is_ring_relay_collective(spec) => 2,
        None => 1,
    }
}
/// Number of persistent relay lanes to use (0 disables them).
///
/// Only ring-relay collectives use lanes. A parseable
/// `WD_RDMA_PERSISTENT_RELAY_LANES` is capped at 8 (0 stays 0). Otherwise
/// `SYMMETRIC_PERSISTENT_RELAY_LANES` applies.
fn relay_persistent_lane_count(spec: &TransferSpec) -> usize {
    if !is_ring_relay_collective(spec) {
        return 0;
    }
    let configured = std::env::var("WD_RDMA_PERSISTENT_RELAY_LANES")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());
    match configured {
        Some(lanes) => lanes.min(8),
        None => SYMMETRIC_PERSISTENT_RELAY_LANES,
    }
}
/// Whether a collective transfer must force synchronous close.
///
/// Currently true exactly for ring-relay collectives. `execute_json` ORs this
/// into `force_sync_close`, which also limits execution to a single worker.
fn should_force_sync_close_for_collective(spec: &TransferSpec) -> bool {
    is_ring_relay_collective(spec)
}
/// Key that identifies the shared write queue for this transfer.
///
/// Without `preserve_existing`, the queue is keyed by the target temp path
/// alone. When existing data is preserved, the key is narrowed:
/// - a single-file manifest uses `<path>#file:<relative_path>`;
/// - any other manifest uses `<path>#<task_id>`.
fn shared_queue_key(spec: &TransferSpec, target_temp_path: &str) -> String {
    if spec.preserve_existing {
        match spec.logical_manifest.files.as_slice() {
            [only_file] => format!("{}#file:{}", target_temp_path, only_file.relative_path),
            _ => format!("{}#{}", target_temp_path, spec.task_id),
        }
    } else {
        target_temp_path.to_string()
    }
}
/// Whether a large multi-file pull from one source should run serialized.
///
/// All of the following must hold:
/// - there is exactly one source;
/// - the transfer is neither a ring-relay collective nor relay-published;
/// - the mode is not `DIRECT_STRIPED`;
/// - there is more than one file;
/// - the total size reaches `HUGE_SINGLE_SOURCE_TOTAL_BYTES`.
fn should_serialize_huge_single_source(spec: &TransferSpec, files: &[ArtifactFile], source_count: usize) -> bool {
    let eligible = source_count == 1
        && !is_ring_relay_collective(spec)
        && resolve_relay_root_publisher(spec).is_none()
        && spec.transfer_mode != "DIRECT_STRIPED"
        && files.len() > 1;
    eligible && files.iter().map(|file| file.size_bytes).sum::<i64>() >= HUGE_SINGLE_SOURCE_TOTAL_BYTES
}
/// Executes a direct pull described by a JSON `TransferSpec` and returns the
/// JSON-serialized `TransferResult`.
///
/// Steps:
/// 1. Parse the spec and resolve the file list.
/// 2. Prepare the target files and build an HTTP client.
/// 3. For `DIRECT_STRIPED`, calibrate and select source segments.
/// 4. Plan the jobs.
/// 5. Run them inline or across scoped worker threads.
/// 6. Flush pending relay publishes.
/// 7. Build a deduplicated chunk manifest for the result.
///
/// `force_sync_close` limits execution to a single worker. It is set by the
/// explicit spec flag, by huge single-source transfers, or by ring-relay
/// collectives.
///
/// # Errors
/// Returns a descriptive `String` for parse, planning, HTTP-client, job,
/// relay-flush, or serialization failures. In the threaded path only the
/// first recorded worker error is returned. Worker groups check the stop flag
/// only before they begin. NOTE(review): on any early error return, the relay
/// batch publisher is not flushed.
///
/// # Panics
/// Panics if the shared results/error mutex is poisoned (`lock().unwrap()`).
pub fn execute_json(request: &str) -> Result<String, String> {
    let spec: TransferSpec = serde_json::from_str(request)
        .map_err(|err| format!("parse direct pull request: {err}"))?;
    let start_at = now_millis();
    let instant = Instant::now();
    let files = resolve_files(&spec)?;
    // Prepared target files, shared by every worker group through an Arc.
    let open_files = Arc::new(prepare_target_files(&spec.target_temp_path, &files, spec.preserve_existing)?);
    let client = Client::builder()
        .timeout(Duration::from_secs(timeout_seconds(&spec)))
        .build()
        .map_err(|err| format!("build HTTP client: {err}"))?;
    let read_timeout = timeout_seconds(&spec) as i32;
    let target_temp_path = spec.target_temp_path.clone();
    // Relay publishing configuration (several values differ for ring-relay collectives).
    let relay_root_publisher = resolve_relay_root_publisher(&spec);
    let relay_group_target = relay_publish_group_target(&spec);
    let relay_publish_inflight = relay_publish_max_inflight(&spec);
    let relay_publish_batch_bytes = relay_publish_batch_bytes(&spec);
    let relay_persistent_lanes = relay_persistent_lane_count(&spec);
    // DIRECT_STRIPED probes the configured sources first and may narrow them down.
    let source_segments = if spec.transfer_mode == "DIRECT_STRIPED" {
        prepare_direct_striped_segments(&client, &spec, &files, read_timeout)?
    } else {
        spec.source_segments.clone()
    };
    let jobs = build_jobs_with_segments(&spec, &files, &source_segments)?;
    let mut worker_count = desired_worker_count(&spec, jobs.len(), source_segments.len());
    let force_sync_close = spec.force_sync_close
        || should_serialize_huge_single_source(&spec, &files, source_segments.len())
        || should_force_sync_close_for_collective(&spec);
    if force_sync_close {
        eprintln!(
            "direct forcing sync close files={} bytes={} requested_parallelism={} original_workers={} explicit_flag={}",
            files.len(),
            files.iter().map(|file| file.size_bytes).sum::<i64>(),
            spec.parallelism,
            worker_count,
            spec.force_sync_close
        );
        // Sync-close transfers run every job on one worker.
        worker_count = 1;
    }
    let mut bytes_transferred = 0_i64;
    // Transport flags, folded into `transport_path` via `PayloadTransport::as_result_path`.
    let mut all_rdma = true;
    let mut any_rdma = false;
    let mut transferred_chunks = Vec::with_capacity(jobs.len());
    let relay_batch_publisher = relay_root_publisher.clone().map(|publisher| {
        RelayBatchPublisher::new(
            client.clone(),
            publisher,
            target_temp_path.clone(),
            relay_group_target,
            relay_publish_inflight,
            relay_publish_batch_bytes,
            Duration::from_secs(timeout_seconds(&spec) as u64),
            relay_persistent_lanes,
        )
    });
    let shared_queue_key = shared_queue_key(&spec, &target_temp_path);
    let retry_limit = source_chunk_retry_limit(&spec);
    if worker_count <= 1 || jobs.len() <= 1 {
        // Inline path: all jobs run as one group on the calling thread.
        let group_outcomes = execute_job_group(
            jobs,
            client.clone(),
            Arc::clone(&open_files),
            read_timeout,
            retry_limit,
            spec.enable_chunk_crc32c,
            target_temp_path.clone(),
            shared_queue_key.clone(),
            relay_batch_publisher.clone(),
            force_sync_close,
        )?;
        for outcome in group_outcomes {
            if outcome.transport == PayloadTransport::Rdma {
                any_rdma = true;
            } else {
                all_rdma = false;
            }
            bytes_transferred += outcome.chunk.size;
            transferred_chunks.push(outcome.chunk);
        }
    } else {
        // Threaded path: one scoped thread per job group. Successful outcomes and
        // the first error are collected behind mutexes; `stop` is raised on error.
        let shared_results = Arc::new(Mutex::new(Vec::<JobOutcome>::new()));
        let first_error = Arc::new(Mutex::new(None::<String>));
        let stop = Arc::new(AtomicBool::new(false));
        let job_groups = build_job_groups(&spec.transfer_mode, jobs, worker_count, source_segments.len());
        std::thread::scope(|scope| {
            for group in job_groups {
                let client = client.clone();
                let results = Arc::clone(&shared_results);
                let first_error = Arc::clone(&first_error);
                let stop = Arc::clone(&stop);
                let open_files = Arc::clone(&open_files);
                let target_temp_path = target_temp_path.clone();
                let shared_queue_key = shared_queue_key.clone();
                let relay_batch_publisher = relay_batch_publisher.clone();
                let force_sync_close = force_sync_close;
                scope.spawn(move || {
                    // Skip the group entirely if another group already failed.
                    if stop.load(Ordering::Relaxed) {
                        return;
                    }
                    let first_relative_path = group
                        .first()
                        .map(|job| job.relative_path.clone())
                        .unwrap_or_else(|| "unknown".to_string());
                    match execute_job_group(
                        group,
                        client.clone(),
                        open_files,
                        read_timeout,
                        retry_limit,
                        spec.enable_chunk_crc32c,
                        target_temp_path.clone(),
                        shared_queue_key,
                        relay_batch_publisher,
                        force_sync_close,
                    ) {
                        Ok(group_outcomes) => {
                            eprintln!(
                                "direct worker-group finished relative_path={} outcomes={}",
                                first_relative_path,
                                group_outcomes.len()
                            );
                            let mut guard = results.lock().unwrap();
                            for outcome in group_outcomes {
                                guard.push(outcome);
                            }
                        }
                        Err(err) => {
                            stop.store(true, Ordering::Relaxed);
                            // Keep only the first error; later ones are dropped.
                            let mut guard = first_error.lock().unwrap();
                            if guard.is_none() {
                                *guard = Some(err);
                            }
                        }
                    }
                });
            }
        });
        if let Some(err) = first_error.lock().unwrap().take() {
            return Err(err);
        }
        eprintln!("direct execute_json worker-groups joined outcomes_pending={}", shared_results.lock().unwrap().len());
        for outcome in shared_results.lock().unwrap().drain(..) {
            if outcome.transport == PayloadTransport::Rdma {
                any_rdma = true;
            } else {
                all_rdma = false;
            }
            bytes_transferred += outcome.chunk.size;
            transferred_chunks.push(outcome.chunk);
        }
    }
    if let Some(publisher) = relay_batch_publisher.as_ref() {
        publisher.flush()?;
    }
    // Release this function's handle to the target files before building the result.
    drop(open_files);
    let deduped_chunks = dedupe_chunks(transferred_chunks);
    let mut file_chunk_counts = BTreeMap::<String, i32>::new();
    for chunk in &deduped_chunks {
        *file_chunk_counts.entry(chunk.relative_path.clone()).or_default() += 1;
    }
    let manifest_files = files
        .iter()
        .map(|file| FileDigest {
            relative_path: file.relative_path.clone(),
            size_bytes: file.size_bytes,
            chunk_count: *file_chunk_counts.get(&file.relative_path).unwrap_or(&0),
        })
        .collect::<Vec<_>>();
    // NOTE(review): `bytes_transferred` sums all outcomes before deduplication,
    // while the chunk counts below use the deduplicated list — confirm intended.
    let result = TransferResult {
        task_id: spec.task_id,
        temp_path: spec.target_temp_path,
        bytes_transferred,
        chunk_count: deduped_chunks.len() as i32,
        succeeded_chunks: deduped_chunks.len() as i32,
        failed_chunks: 0,
        throughput_m_bps: throughput(bytes_transferred, instant.elapsed()),
        transport_path: PayloadTransport::as_result_path(all_rdma, any_rdma),
        started_at: start_at,
        finished_at: now_millis(),
        transferred_chunks: deduped_chunks.clone(),
        solidified_manifest: Some(SolidifiedManifest {
            artifact_key: spec.artifact_key,
            logical_digest: spec.logical_manifest.digest,
            chunk_size_bytes: spec.chunk_size_bytes,
            generated_at: now_millis(),
            files: manifest_files,
            chunks: deduped_chunks,
        }),
    };
    serde_json::to_string(&result).map_err(|err| format!("serialize direct pull result: {err}"))
}
/// Determines the files covered by the transfer.
///
/// If the logical manifest lists files, those are used. Otherwise the
/// function falls back to a single file:
/// - named after the last path component of the first source segment;
/// - sized via `stat_source`;
/// - tagged as a required, chunkable `SAFETENSORS` file.
///
/// # Errors
/// Fails when:
/// - no source segments are configured;
/// - the source path has no file-name component;
/// - `stat_source` fails.
fn resolve_files(spec: &TransferSpec) -> Result<Vec<ArtifactFile>, String> {
    let manifest_files = &spec.logical_manifest.files;
    if manifest_files.is_empty() {
        let Some(segment) = spec.source_segments.first() else {
            return Err("no source segments configured".to_string());
        };
        let Some(file_name) = Path::new(&segment.source_endpoint.path).file_name() else {
            return Err("cannot derive relative path from source path".to_string());
        };
        let relative_path = file_name.to_string_lossy().into_owned();
        let size_bytes = stat_source(segment, "")?;
        return Ok(vec![ArtifactFile {
            relative_path,
            size_bytes,
            kind: "SAFETENSORS".to_string(),
            chunkable: true,
            required: true,
        }]);
    }
    Ok(manifest_files.clone())
}
/// Plans transfer jobs for `spec` using its configured `source_segments`.
///
/// Thin wrapper over `build_jobs_with_segments`. `execute_json` calls that
/// function directly instead, because `DIRECT_STRIPED` may replace the
/// segment list first.
fn build_jobs(spec: &TransferSpec, files: &[ArtifactFile]) -> Result<Vec<TransferJob>, String> {
    build_jobs_with_segments(spec, files, &spec.source_segments)
}
/// Plans transfer jobs for the given mode and source segments.
///
/// Modes:
/// - `SINGLE_SOURCE_DIRECT`: segments with explicit byte ranges become range
///   jobs; other segments become whole-file jobs, which require manifest files.
/// - `DIRECT_STRIPED`: manifest chunks are spread across the sources by weight.
/// - `PARTIAL_PULL_ALLGATHER`: every segment must carry explicit byte ranges.
///
/// For the two direct modes, a positive `chunk_size_bytes` below
/// `RDMA_DIRECT_JOB_BYTES` causes the jobs to be coalesced up to that size.
///
/// # Errors
/// Fails on an unknown mode, missing manifest files, missing byte ranges, or
/// errors from striped planning.
fn build_jobs_with_segments(
    spec: &TransferSpec,
    files: &[ArtifactFile],
    source_segments: &[SourceSegmentPlan],
) -> Result<Vec<TransferJob>, String> {
    let mode = spec.transfer_mode.as_str();
    let jobs = match mode {
        "SINGLE_SOURCE_DIRECT" => {
            let mut planned = Vec::new();
            for segment in source_segments {
                if !segment.byte_ranges.is_empty() {
                    append_range_jobs(&mut planned, segment);
                } else if files.is_empty() {
                    return Err("logical manifest files are required for direct pull".to_string());
                } else {
                    append_full_file_jobs(&mut planned, segment, files, spec.chunk_size_bytes);
                }
            }
            planned
        }
        "DIRECT_STRIPED" => build_direct_striped_jobs(source_segments, files, spec.chunk_size_bytes)?,
        "PARTIAL_PULL_ALLGATHER" => {
            let mut planned = Vec::new();
            for segment in source_segments {
                if segment.byte_ranges.is_empty() {
                    return Err(format!("{} mode requires explicit byte ranges", spec.transfer_mode));
                }
                append_range_jobs(&mut planned, segment);
            }
            planned
        }
        other => return Err(format!("unsupported transfer mode {other}")),
    };
    let needs_coalescing = matches!(mode, "SINGLE_SOURCE_DIRECT" | "DIRECT_STRIPED")
        && spec.chunk_size_bytes > 0
        && spec.chunk_size_bytes < RDMA_DIRECT_JOB_BYTES;
    if needs_coalescing {
        Ok(coalesce_jobs(jobs, RDMA_DIRECT_JOB_BYTES))
    } else {
        Ok(jobs)
    }
}
/// Plans `DIRECT_STRIPED` jobs.
///
/// The manifest is cut into `chunk_size_bytes` chunks, and each chunk is
/// assigned to a source segment in proportion to the segment weights.
///
/// # Errors
/// Fails when there are no source segments or no manifest files.
fn build_direct_striped_jobs(
    source_segments: &[SourceSegmentPlan],
    files: &[ArtifactFile],
    chunk_size_bytes: i64,
) -> Result<Vec<TransferJob>, String> {
    match (source_segments.is_empty(), files.is_empty()) {
        (true, _) => Err("DIRECT_STRIPED mode requires at least one source segment".to_string()),
        (false, true) => Err("logical manifest files are required for DIRECT_STRIPED mode".to_string()),
        (false, false) => {
            let chunks = build_manifest_chunks(files, chunk_size_bytes);
            if chunks.is_empty() {
                Ok(Vec::new())
            } else {
                Ok(assign_manifest_chunks_weighted(chunks, source_segments))
            }
        }
    }
}
/// Returns the calibration probe chunk.
///
/// This is the first chunk of the first non-empty file, sized by
/// `DIRECT_STRIPED_PROBE_CHUNK_BYTES`. Returns `None` when every file is empty.
fn build_direct_striped_probe_chunk(files: &[ArtifactFile]) -> Option<ManifestChunk> {
    let mut candidates = build_manifest_chunks(files, DIRECT_STRIPED_PROBE_CHUNK_BYTES);
    if candidates.is_empty() {
        None
    } else {
        Some(candidates.swap_remove(0))
    }
}
/// Splits every non-empty file into consecutive chunks.
///
/// - A `chunk_size_bytes` in `1..=size` is used as the chunk size.
/// - A non-positive or oversized value yields one whole-file chunk.
/// - The last chunk of a file may be shorter than the chunk size.
/// - Files with `size_bytes <= 0` produce no chunks.
///
/// Output is ordered by `files`, then by ascending offset.
fn build_manifest_chunks(files: &[ArtifactFile], chunk_size_bytes: i64) -> Vec<ManifestChunk> {
    let mut manifest = Vec::new();
    for file in files.iter().filter(|file| file.size_bytes > 0) {
        let total = file.size_bytes;
        let step = if (1..=total).contains(&chunk_size_bytes) {
            chunk_size_bytes
        } else {
            total
        };
        let mut cursor = 0_i64;
        while cursor < total {
            let length = step.min(total - cursor);
            manifest.push(ManifestChunk {
                relative_path: file.relative_path.clone(),
                offset: cursor,
                length,
            });
            cursor += length;
        }
    }
    manifest
}
/// Greedily assigns each manifest chunk to the least-loaded source segment.
///
/// Load is measured as assigned bytes divided by weight; weights below 1
/// count as 1. Ties are broken as follows:
/// 1. prefer the heavier weight;
/// 2. then prefer fewer assigned bytes;
/// 3. then keep the earlier segment.
///
/// The resulting jobs are sorted by `(source_id, relative_path, offset)`.
///
/// # Panics
/// Panics if `chunks` is non-empty while `source_segments` is empty.
fn assign_manifest_chunks_weighted(
    chunks: Vec<ManifestChunk>,
    source_segments: &[SourceSegmentPlan],
) -> Vec<TransferJob> {
    let weights: Vec<i128> = source_segments
        .iter()
        .map(|segment| i128::from(segment.weight.max(1)))
        .collect();
    let mut load = vec![0_i128; source_segments.len()];
    // Should `candidate` replace `current`? load/weight ratios are compared by
    // cross-multiplication so the comparison stays exact in integers.
    let prefers = |load: &[i128], candidate: usize, current: usize| -> bool {
        let candidate_ratio = load[candidate] * weights[current];
        let current_ratio = load[current] * weights[candidate];
        if candidate_ratio != current_ratio {
            return candidate_ratio < current_ratio;
        }
        weights[candidate] > weights[current]
            || (weights[candidate] == weights[current] && load[candidate] < load[current])
    };
    let mut jobs = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        let target = (1..source_segments.len()).fold(0_usize, |current, candidate| {
            if prefers(&load, candidate, current) {
                candidate
            } else {
                current
            }
        });
        let segment = &source_segments[target];
        load[target] += i128::from(chunk.length);
        jobs.push(TransferJob {
            source_id: segment.source_id.clone(),
            source: segment.source_endpoint.clone(),
            relative_path: chunk.relative_path,
            offset: chunk.offset,
            length: chunk.length,
            source_offset: chunk.offset,
        });
    }
    jobs.sort_by(|left, right| {
        (&left.source_id, &left.relative_path, left.offset)
            .cmp(&(&right.source_id, &right.relative_path, right.offset))
    });
    jobs
}
/// Calibrates `DIRECT_STRIPED` source weights, then possibly narrows the
/// sources to the single fastest one.
///
/// # Errors
/// Propagates calibration errors.
fn prepare_direct_striped_segments(
    client: &Client,
    spec: &TransferSpec,
    files: &[ArtifactFile],
    timeout_seconds: i32,
) -> Result<Vec<SourceSegmentPlan>, String> {
    calibrate_direct_striped_segments(client, spec, files, timeout_seconds)
        .map(|calibrated| select_direct_striped_segments(spec, calibrated))
}
/// Replaces each source segment's weight with a measured probe throughput.
///
/// Calibration runs only when all of these hold:
/// - there is more than one segment;
/// - the manifest exceeds 512 MiB;
/// - a probe chunk exists.
///
/// Otherwise the configured segments are returned unchanged. If a probe
/// fails, the segment keeps its configured weight (at least 1). Measured
/// weights are clamped to `1..=i32::MAX`.
fn calibrate_direct_striped_segments(
    client: &Client,
    spec: &TransferSpec,
    files: &[ArtifactFile],
    timeout_seconds: i32,
) -> Result<Vec<SourceSegmentPlan>, String> {
    let total_bytes: i64 = files.iter().map(|file| file.size_bytes).sum();
    let worth_probing = spec.source_segments.len() > 1 && total_bytes > 512 * 1024 * 1024;
    let probe = if worth_probing {
        build_direct_striped_probe_chunk(files)
    } else {
        None
    };
    let Some(probe_chunk) = probe else {
        return Ok(spec.source_segments.clone());
    };
    let calibration_started = Instant::now();
    let mut calibrated = spec.source_segments.clone();
    let mut measurements = Vec::with_capacity(calibrated.len());
    for segment in calibrated.iter_mut() {
        let probe_job = TransferJob {
            source_id: segment.source_id.clone(),
            source: segment.source_endpoint.clone(),
            relative_path: probe_chunk.relative_path.clone(),
            offset: probe_chunk.offset,
            length: probe_chunk.length,
            source_offset: probe_chunk.offset,
        };
        let measured = match measure_source_probe_weight(client, &probe_job, timeout_seconds) {
            Ok(weight) => weight,
            Err(err) => {
                eprintln!(
                    "direct striped probe fallback source={} endpoint={} reason={}",
                    segment.source_id, segment.source_endpoint.endpoint, err
                );
                i64::from(segment.weight.max(1))
            }
        };
        segment.weight = measured.clamp(1, i64::from(i32::MAX)) as i32;
        measurements.push(format!("{}={}", segment.source_id, segment.weight));
    }
    eprintln!(
        "direct striped calibration sources={} probe_bytes={} rounds={} total_ms={}",
        calibrated.len(),
        probe_chunk.length,
        DIRECT_STRIPED_PROBE_READ_ROUNDS,
        calibration_started.elapsed().as_millis()
    );
    eprintln!("direct striped calibrated source weights {}", measurements.join(","));
    Ok(calibrated)
}
/// Worker count assumed when predicting single-source striped throughput.
///
/// Uses the requested parallelism when it is positive, otherwise 8. The
/// result is always clamped to `1..=8`.
fn striped_single_source_worker_budget(spec: &TransferSpec) -> usize {
    const WORKER_CAP: usize = 8;
    let requested = if spec.parallelism <= 0 {
        WORKER_CAP
    } else {
        spec.parallelism as usize
    };
    requested.clamp(1, WORKER_CAP)
}
/// Collapses striping to the single fastest source when it alone is
/// predicted to be fast enough.
///
/// The prediction is `max(weight, 1) * worker budget`. If it reaches
/// `DIRECT_STRIPED_SINGLE_SOURCE_THRESHOLD_MBPS`, only that segment is
/// returned. Among equally weighted segments, the last one wins
/// (`Iterator::max_by_key` semantics). Lists with fewer than two segments are
/// returned unchanged.
fn select_direct_striped_segments(
    spec: &TransferSpec,
    segments: Vec<SourceSegmentPlan>,
) -> Vec<SourceSegmentPlan> {
    if segments.len() < 2 {
        return segments;
    }
    let worker_budget = striped_single_source_worker_budget(spec) as i64;
    let fastest_index = segments
        .iter()
        .enumerate()
        .max_by_key(|(_, segment)| segment.weight.max(1))
        .map(|(index, _)| index);
    let Some(fastest_index) = fastest_index else {
        return segments;
    };
    let predicted_mbps = i64::from(segments[fastest_index].weight.max(1)) * worker_budget;
    if predicted_mbps < DIRECT_STRIPED_SINGLE_SOURCE_THRESHOLD_MBPS {
        return segments;
    }
    eprintln!(
        "direct striped selecting single source source={} predicted_single_source_mbps={} threshold_mbps={}",
        segments[fastest_index].source_id, predicted_mbps, DIRECT_STRIPED_SINGLE_SOURCE_THRESHOLD_MBPS
    );
    let mut segments = segments;
    vec![segments.swap_remove(fastest_index)]
}
fn measure_source_probe_weight(
client: &Client,
job: &TransferJob,
timeout_seconds: i32,
) -> Result<i64, String> {
if job.source.endpoint.is_empty() {
let read_started = Instant::now();
for _ in 0..DIRECT_STRIPED_PROBE_READ_ROUNDS {
let _ = read_local_chunk(&job.source.path, &job.relative_path, job.source_offset, job.length)?;
}
let elapsed = read_started.elapsed();
let bytes_measured = job.length * DIRECT_STRIPED_PROBE_READ_ROUNDS as i64;
if elapsed.is_zero() {
return Ok(bytes_measured.max(1));
}
return Ok(
((bytes_measured as f64 / 1024_f64 / 1024_f64) / elapsed.as_secs_f64()).round() as i64,
);
}
let mut session = ReusableRdmaSession::open(client, job, timeout_seconds)?;
session.prepare_current(client, job)?;
let read_started = Instant::now();
for _ in 0..DIRECT_STRIPED_PROBE_READ_ROUNDS {
let _ = session.read(0, job.length as usize)?;
}
let elapsed = read_started.elapsed();
let _ = close_rdma_session(client, &session.endpoint, &session.session_id);
let bytes_measured = job.length * DIRECT_STRIPED_PROBE_READ_ROUNDS as i64;
if elapsed.is_zero() {
return Ok(bytes_measured.max(1));
}
Ok(((bytes_measured as f64 / 1024_f64 / 1024_f64) / elapsed.as_secs_f64()).round() as i64)
}
/// Result of one transfer job as collected by `execute_json`.
///
/// `chunk.size` is summed into `bytes_transferred`, and `transport` feeds the
/// all/any-RDMA flags.
#[derive(Debug, Clone)]
struct JobOutcome {
    chunk: TransferredChunk,
    transport: PayloadTransport,
    // Millisecond timings; presumably read phase, write phase and whole job — confirm.
    read_ms: u128,
    write_ms: u128,
    total_ms: u128,
}
/// A written chunk queued for relay publishing, together with its payload.
///
/// `payload` holds a copy of the chunk bytes when relay payload capture is
/// enabled.
#[derive(Clone)]
struct RelayPublishedChunk {
    chunk: TransferredChunk,
    transport: PayloadTransport,
    payload: Vec<u8>,
}
/// Outcome of a completed chunk write.
///
/// `relay_chunk` is set only when the relay payload was captured for
/// publishing.
struct WriteCompletion {
    outcome: JobOutcome,
    relay_chunk: Option<RelayPublishedChunk>,
}
/// Where a freshly read payload lives.
///
/// - `Owned`: the bytes are in a heap buffer (local and HTTP reads).
/// - `Borrowed`: the first `length` bytes of the reader session's local RDMA
///   slot `slot_index`.
enum PayloadLease {
    Owned(Vec<u8>),
    Borrowed { slot_index: usize, length: usize },
}
/// A background RDMA export update started for a prefetched job.
///
/// `handle` yields the remote descriptor and optional expected CRC32C once
/// the update request finishes.
struct PendingRefresh {
    job: TransferJob,
    handle: JoinHandle<Result<(RdmaEndpointInfo, Option<String>), String>>,
}
/// One registered local RDMA receive buffer.
///
/// `registered` is the RDMA memory registration of the `Vec<u8>` held in
/// `buffer` (registered in `ReusableRdmaSession::open` before the Vec is moved
/// in; moving a Vec does not move its heap allocation). The buffer sits in an
/// `UnsafeCell` because RDMA reads presumably write into it outside Rust's
/// borrow tracking, and it is read back through `slice`.
struct LocalReadSlot {
    registered: RegisteredMemory,
    buffer: UnsafeCell<Vec<u8>>,
}
// SAFETY: NOTE(review): these impls assert that a slot may be shared across threads.
// Nothing in the type prevents an RDMA read into the slot from overlapping a
// reader of `slice`. Callers presumably guarantee exclusivity by protocol,
// e.g. by not reusing a slot until its pending write completes — confirm.
unsafe impl Send for LocalReadSlot {}
unsafe impl Sync for LocalReadSlot {}
impl LocalReadSlot {
    /// Returns the first `length` bytes of the slot buffer.
    ///
    /// # Panics
    /// Panics if `length` exceeds the buffer length.
    fn slice(&self, length: usize) -> &[u8] {
        // SAFETY: sound only if nothing mutates the buffer while the returned
        // slice is alive (see the Send/Sync note above); not checked here.
        let buffer = unsafe { &*self.buffer.get() };
        &buffer[..length]
    }
}
/// Payload of a read awaiting its write.
///
/// - `Owned`: heap bytes.
/// - `Borrowed`: a shared handle to the RDMA slot plus the number of valid
///   bytes in it.
enum PendingPayload {
    Owned { payload: Vec<u8> },
    Borrowed {
        slot: Arc<LocalReadSlot>,
        length: usize,
    },
}
/// A chunk that has been read and is ready to be written to its target file.
struct PendingWrite {
    job: TransferJob,
    payload: PendingPayload,
    transport: PayloadTransport,
    // Milliseconds spent in the read phase.
    read_ms: u128,
    // Taken when `prepare_pending_write` began, just before the read.
    started: Instant,
}
/// A write handed to a shared write queue.
///
/// The result is reported back on `completion_tx`.
struct SharedWriteTask {
    pending: PendingWrite,
    enable_chunk_crc32c: bool,
    // Whether to keep the payload bytes for relay publishing.
    capture_relay_payload: bool,
    completion_tx: SyncSender<Result<WriteCompletion, String>>,
}
/// Handle to a shared write queue (a bounded `SyncSender` of write tasks).
struct SharedWriteQueue {
    pending_tx: SyncSender<SharedWriteTask>,
    // Presumably the number of job groups currently using this queue — confirm
    // against the queue registry that creates it.
    active_users: AtomicUsize,
}
/// Reads job payloads, keeping one RDMA session open across consecutive jobs
/// that the session matches.
struct ReusableRdmaReader {
    session: Option<ReusableRdmaSession>,
}
impl ReusableRdmaReader {
    /// Creates a reader with no open session.
    fn new() -> Self {
        Self { session: None }
    }
    /// Reads the bytes of `job`.
    ///
    /// Sources, in order of precedence:
    /// - a local file, when the endpoint is empty (reported as `TcpFallback`);
    /// - HTTP, when `force_tcp_fallback_enabled()` is set;
    /// - RDMA into `local_slot`, falling back to HTTP after any RDMA error.
    ///
    /// Retries (up to `retry_limit`) happen only when CRC32C verification
    /// fails. Errors from `read_local_chunk` and `read_payload_via_http`
    /// propagate immediately via `?`.
    fn read_payload(
        &mut self,
        client: &Client,
        job: &TransferJob,
        next_job: Option<&TransferJob>,
        allow_prefetch: bool,
        local_slot: usize,
        timeout_seconds: i32,
        retry_limit: usize,
    ) -> Result<(PayloadLease, PayloadTransport), String> {
        if job.source.endpoint.is_empty() {
            for attempt in 0..=retry_limit {
                let payload = read_local_chunk(&job.source.path, &job.relative_path, job.source_offset, job.length)?;
                // NOTE(review): the expected CRC is computed from the payload itself,
                // so this check presumably always passes and the retry never fires.
                // Confirm whether a stored source checksum was intended here.
                let expected = Some(encode_crc32c(crc32c::crc32c(&payload)));
                match verify_source_crc32c(&payload, expected.as_deref()) {
                    Ok(()) => return Ok((PayloadLease::Owned(payload), PayloadTransport::TcpFallback)),
                    Err(err) if attempt < retry_limit => {
                        eprintln!("retry local chunk {}@{} attempt={} reason={}", job.relative_path, job.offset, attempt + 1, err);
                    }
                    Err(err) => return Err(err),
                }
            }
            // Not reached in practice: the last attempt always returns above.
            return Err("local retries exhausted".to_string());
        }
        if force_tcp_fallback_enabled() {
            for attempt in 0..=retry_limit {
                let (payload, expected_crc32c) = read_payload_via_http(client, job)?;
                match verify_source_crc32c(&payload, expected_crc32c.as_deref()) {
                    Ok(()) => return Ok((PayloadLease::Owned(payload), PayloadTransport::TcpFallback)),
                    Err(err) if attempt < retry_limit => {
                        eprintln!("retry http chunk {}@{} attempt={} reason={}", job.relative_path, job.offset, attempt + 1, err);
                    }
                    Err(err) => return Err(err),
                }
            }
            return Err("http retries exhausted".to_string());
        }
        match self.read_payload_via_rdma(
            client,
            job,
            next_job,
            allow_prefetch,
            local_slot,
            timeout_seconds,
            retry_limit,
        ) {
            // RDMA data stays in the session's slot; only its index and length are returned.
            Ok(length) => Ok((
                PayloadLease::Borrowed {
                    slot_index: local_slot,
                    length,
                },
                PayloadTransport::Rdma,
            )),
            Err(err) => {
                eprintln!(
                    "rdma fallback for {}@{} ({} bytes): {}",
                    job.relative_path, job.offset, job.length, err
                );
                // Drop the (possibly broken) session before falling back to HTTP.
                self.close(client);
                for attempt in 0..=retry_limit {
                    let (payload, expected_crc32c) = read_payload_via_http(client, job)?;
                    match verify_source_crc32c(&payload, expected_crc32c.as_deref()) {
                        Ok(()) => return Ok((PayloadLease::Owned(payload), PayloadTransport::TcpFallback)),
                        Err(retry_err) if attempt < retry_limit => {
                            eprintln!("retry fallback http chunk {}@{} attempt={} reason={}", job.relative_path, job.offset, attempt + 1, retry_err);
                        }
                        Err(retry_err) => return Err(retry_err),
                    }
                }
                Err("http fallback retries exhausted".to_string())
            }
        }
    }
    /// Reads `job` over RDMA into `local_slot` and returns the byte count.
    ///
    /// Steps:
    /// 1. Reopen the session if the current one does not match `job`.
    /// 2. Load the job into the session.
    /// 3. Optionally start a background update for `next_job` (only when the
    ///    same session can serve it).
    /// 4. Read into the slot and verify the CRC32C from the exporter, if any.
    ///
    /// On a CRC mismatch, the pending prefetch is discarded and the read is
    /// retried.
    fn read_payload_via_rdma(
        &mut self,
        client: &Client,
        job: &TransferJob,
        next_job: Option<&TransferJob>,
        allow_prefetch: bool,
        local_slot: usize,
        timeout_seconds: i32,
        retry_limit: usize,
    ) -> Result<usize, String> {
        for attempt in 0..=retry_limit {
            let needs_new_session = self
                .session
                .as_ref()
                .map(|session| !session.matches(job))
                .unwrap_or(true);
            if needs_new_session {
                self.close(client);
                self.session = Some(ReusableRdmaSession::open(client, job, timeout_seconds)?);
            }
            let session = self.session.as_mut().ok_or_else(|| "rdma session not initialized".to_string())?;
            session.prepare_current(client, job)?;
            if allow_prefetch {
                if let Some(next_job) = next_job.filter(|next| session.matches(next)) {
                    session.start_prefetch(client, next_job)?;
                }
            }
            let length = session.read(local_slot, job.length as usize)?;
            let slot = session.local_slot(local_slot);
            match verify_source_crc32c(slot.slice(length), session.expected_crc32c.as_deref()) {
                Ok(()) => return Ok(length),
                Err(err) if attempt < retry_limit => {
                    eprintln!("retry rdma chunk {}@{} attempt={} reason={}", job.relative_path, job.offset, attempt + 1, err);
                    // Dropping the JoinHandle detaches the prefetch thread without joining it.
                    session.pending_refresh = None;
                    continue;
                }
                Err(err) => return Err(err),
            }
        }
        Err("rdma retries exhausted".to_string())
    }
    /// Drops the current session, if any.
    ///
    /// Any pending refresh is joined first. For non-relay sessions, the remote
    /// export is then closed synchronously; close errors are ignored.
    fn close(&mut self, client: &Client) {
        if let Some(mut session) = self.session.take() {
            session.drain_pending();
            if !session.relay_rdma {
                let _ = close_rdma_session(client, &session.endpoint, &session.session_id);
            }
        }
    }
    /// Like `close`, but the remote close request runs on a detached thread.
    ///
    /// The local session is dropped before that thread is spawned.
    fn close_background(&mut self, client: Client) {
        if let Some(mut session) = self.session.take() {
            session.drain_pending();
            if session.relay_rdma {
                return;
            }
            let endpoint = session.endpoint.clone();
            let session_id = session.session_id.clone();
            drop(session);
            std::thread::spawn(move || {
                let _ = close_rdma_session(&client, &endpoint, &session_id);
            });
        }
    }
    /// Returns a shared handle to slot `slot_index` of the current session.
    ///
    /// Returns `None` when no session is open. Panics if the index is out of
    /// range.
    fn slot_handle(&self, slot_index: usize) -> Option<Arc<LocalReadSlot>> {
        self.session.as_ref().map(|session| session.local_slot(slot_index))
    }
}
/// Waits for a background export update and returns its result.
///
/// A panic in the update thread is converted into an error that names
/// `session_id`.
fn join_pending_refresh(session_id: &str, pending: PendingRefresh) -> Result<(RdmaEndpointInfo, Option<String>), String> {
    match pending.handle.join() {
        Ok(refresh_result) => refresh_result,
        Err(_panic_payload) => Err(format!("rdma export prefetch thread panicked for session {}", session_id)),
    }
}
/// Re-points a remote RDMA export session at `job`'s source range.
///
/// Posts to `<endpoint>/v1/exports/rdma/update`. Returns:
/// - the remote descriptor, with `length` overridden to `job.length`;
/// - the exporter's expected CRC32C, or `None` when it sent an empty one.
///
/// # Errors
/// Fails on transport errors, a non-success status, an undecodable body, a
/// non-RDMA transport path, or an invalid connection descriptor.
fn update_rdma_export_session(
    client: &Client,
    endpoint: &str,
    session_id: &str,
    job: &TransferJob,
) -> Result<(RdmaEndpointInfo, Option<String>), String> {
    let request = UpdateRDMAExportRequest {
        session_id: session_id.to_string(),
        offset: job.source_offset,
        length: job.length,
    };
    let url = format!("{}/v1/exports/rdma/update", endpoint);
    let response = match client.post(url).json(&request).send() {
        Ok(response) => response,
        Err(err) => return Err(format!("update rdma export session {}: {err}", session_id)),
    };
    let status = response.status();
    if !status.is_success() {
        return Err(format!(
            "update rdma export failed for session {} with status {}",
            session_id, status
        ));
    }
    let body = response
        .json::<UpdateRDMAExportResponse>()
        .map_err(|err| format!("decode rdma export update response {}: {err}", session_id))?;
    if body.transport_path != "RDMA" {
        return Err(format!("rdma export session {} did not stay in RDMA mode", session_id));
    }
    let mut remote = rdma_endpoint_from_wire(&body.connection_info)?;
    remote.length = job.length as u64;
    let expected_crc32c = Some(body.expected_crc32c).filter(|crc| !crc.is_empty());
    Ok((remote, expected_crc32c))
}
/// An RDMA connection to one remote export session, reusable across jobs on
/// the same file.
struct ReusableRdmaSession {
    // Identity checked by `matches` before the session is reused.
    source_id: String,
    endpoint: String,
    root_path: String,
    relative_path: String,
    // Size of each local slot: the job length the session was opened with.
    capacity: i64,
    // Remote export session id (from the exporter, or from the relay descriptor).
    session_id: String,
    // RDMA_LOCAL_SLOT_COUNT registered receive buffers.
    local_slots: Vec<Arc<LocalReadSlot>>,
    transport: RdmaTransport,
    // Descriptor as first connected; relay sessions offset from it per job.
    base_remote: RdmaEndpointInfo,
    // Descriptor for the currently loaded job.
    remote: RdmaEndpointInfo,
    // CRC32C the exporter reported for the loaded job, if any.
    expected_crc32c: Option<String>,
    loaded_job: TransferJob,
    // In-flight background update for a prefetched job.
    pending_refresh: Option<PendingRefresh>,
    // True when the session came from a relay descriptor instead of /open.
    relay_rdma: bool,
}
impl ReusableRdmaSession {
fn open(client: &Client, job: &TransferJob, timeout_seconds: i32) -> Result<Self, String> {
let started = Instant::now();
let relay_rdma = job.source.relay_rdma.clone();
let mut open_ms = 0_u128;
let (session_id, connection_info, expected_crc32c) = if let Some(relay) = relay_rdma.as_ref() {
(relay.session_id.clone(), relay.connection_info.clone(), None)
} else {
let open_started = Instant::now();
let open_response = client
.post(format!("{}/v1/exports/rdma/open", job.source.endpoint))
.json(&OpenRDMAExportRequest {
root_path: &job.source.path,
relative_path: &job.relative_path,
offset: job.source_offset,
length: job.length,
})
.send()
.map_err(|err| format!("open rdma export session {}@{}: {err}", job.relative_path, job.source_offset))?;
if !open_response.status().is_success() {
return Err(format!(
"open rdma export failed for {}@{} with status {}",
job.relative_path,
job.source_offset,
open_response.status()
));
}
let open_body: OpenRDMAExportResponse = open_response
.json()
.map_err(|err| format!("decode rdma export open response {}@{}: {err}", job.relative_path, job.source_offset))?;
if open_body.transport_path != "RDMA" {
return Err(format!(
"rdma export session {} not opened in RDMA mode",
open_body.session_id
));
}
open_ms = open_started.elapsed().as_millis();
(
open_body.session_id,
open_body.connection_info,
if open_body.expected_crc32c.is_empty() { None } else { Some(open_body.expected_crc32c) },
)
};
let local_setup_started = Instant::now();
let transport = RdmaTransport::open_default().map_err(|err| format!("open local rdma transport: {err}"))?;
let mut local_slots = Vec::with_capacity(RDMA_LOCAL_SLOT_COUNT);
for slot_index in 0..RDMA_LOCAL_SLOT_COUNT {
let mut slot_buffer = vec![0_u8; job.length as usize];
let slot_registered = transport
.register_memory(&mut slot_buffer)
.map_err(|err| format!("register local rdma memory slot {}: {err}", slot_index))?;
local_slots.push(Arc::new(LocalReadSlot {
buffer: UnsafeCell::new(slot_buffer),
registered: slot_registered,
}));
}
let local_endpoint = transport
.local_endpoint(&local_slots[0].registered)
.map_err(|err| format!("query local rdma endpoint: {err}"))?;
let local_setup_ms = local_setup_started.elapsed().as_millis();
let connect_remote_started = Instant::now();
let connect_response = client
.post(format!("{}/v1/exports/rdma/connect", job.source.endpoint))
.json(&ConnectRDMAExportRequest {
session_id: session_id.clone(),
connection_info: rdma_endpoint_to_wire(&transport.device_name().to_string(), &local_endpoint),
timeout_seconds,
})
.send()
.map_err(|err| format!("connect rdma export session {}: {err}", session_id))?;
if !connect_response.status().is_success() {
let _ = close_rdma_session(client, &job.source.endpoint, &session_id);
return Err(format!(
"connect rdma export failed for session {} with status {}",
session_id,
connect_response.status()
));
}
let connect_body: ConnectRDMAExportResponse = connect_response
.json()
.map_err(|err| format!("decode rdma export connect response {}: {err}", session_id))?;
if connect_body.transport_path != "RDMA" {
let _ = close_rdma_session(client, &job.source.endpoint, &session_id);
return Err(format!(
"rdma export session {} did not stay in RDMA mode",
session_id
));
}
let connect_remote_ms = connect_remote_started.elapsed().as_millis();
let remote = rdma_endpoint_from_wire(&connection_info)?;
let connect_local_started = Instant::now();
if let Err(err) = transport.connect_rc(&remote) {
let _ = close_rdma_session(client, &job.source.endpoint, &session_id);
return Err(format!("connect local qp to remote session {}: {err}", session_id));
}
let connect_local_ms = connect_local_started.elapsed().as_millis();
eprintln!(
"rdma session open source={} session={} relative_path={} offset={} length={} relay_rdma={} open_ms={} local_setup_ms={} connect_remote_ms={} connect_local_ms={} total_ms={}",
job.source_id,
session_id,
job.relative_path,
job.source_offset,
job.length,
relay_rdma.is_some(),
open_ms,
local_setup_ms,
connect_remote_ms,
connect_local_ms,
started.elapsed().as_millis()
);
Ok(Self {
source_id: job.source_id.clone(),
endpoint: job.source.endpoint.clone(),
root_path: job.source.path.clone(),
relative_path: job.relative_path.clone(),
capacity: job.length,
session_id,
transport,
local_slots,
base_remote: remote.clone(),
remote,
expected_crc32c,
loaded_job: job.clone(),
pending_refresh: None,
relay_rdma: relay_rdma.is_some(),
})
}
fn matches(&self, job: &TransferJob) -> bool {
self.source_id == job.source_id
&& self.endpoint == job.source.endpoint
&& self.root_path == job.source.path
&& self.relative_path == job.relative_path
&& self.relay_rdma == job.source.relay_rdma.is_some()
&& job.length <= self.capacity
}
fn prepare_current(&mut self, client: &Client, job: &TransferJob) -> Result<(), String> {
if self.relay_rdma {
if self.loaded_job.source_offset != job.source_offset || self.loaded_job.length != job.length {
self.remote = self.base_remote.clone();
self.remote.addr = self
.base_remote
.addr
.checked_add(job.source_offset as u64)
.ok_or_else(|| format!("relay rdma address overflow for session {}", self.session_id))?;
self.remote.length = job.length as u64;
self.loaded_job = job.clone();
}
return Ok(());
}
if same_job(&self.loaded_job, job) {
return Ok(());
}
if let Some(pending) = self.pending_refresh.take() {
if same_job(&pending.job, job) {
let wait_started = Instant::now();
let (remote, expected_crc32c) = join_pending_refresh(self.session_id.as_str(), pending)?;
self.remote = remote;
self.expected_crc32c = expected_crc32c;
eprintln!(
"rdma session reuse-pending source={} session={} relative_path={} offset={} length={} wait_ms={}",
self.source_id,
self.session_id,
job.relative_path,
job.source_offset,
job.length,
wait_started.elapsed().as_millis()
);
self.loaded_job = job.clone();
return Ok(());
}
let (remote, expected_crc32c) = join_pending_refresh(self.session_id.as_str(), pending)?;
self.remote = remote;
self.expected_crc32c = expected_crc32c;
}
let update_started = Instant::now();
let (remote, expected_crc32c) = update_rdma_export_session(client, &self.endpoint, &self.session_id, job)?;
self.remote = remote;
self.expected_crc32c = expected_crc32c;
eprintln!(
"rdma session update source={} session={} relative_path={} offset={} length={} update_ms={}",
self.source_id,
self.session_id,
job.relative_path,
job.source_offset,
job.length,
update_started.elapsed().as_millis()
);
self.loaded_job = job.clone();
Ok(())
}
fn start_prefetch(&mut self, client: &Client, job: &TransferJob) -> Result<(), String> {
if self.relay_rdma {
return Ok(());
}
if self.pending_refresh.as_ref().is_some_and(|pending| same_job(&pending.job, job)) {
return Ok(());
}
if let Some(pending) = self.pending_refresh.take() {
let (remote, expected_crc32c) = join_pending_refresh(self.session_id.as_str(), pending)?;
self.remote = remote;
self.expected_crc32c = expected_crc32c;
}
let client = client.clone();
let endpoint = self.endpoint.clone();
let session_id = self.session_id.clone();
let pending_job = job.clone();
let update_job = job.clone();
let handle =
std::thread::spawn(move || update_rdma_export_session(&client, &endpoint, &session_id, &update_job));
self.pending_refresh = Some(PendingRefresh {
job: pending_job,
handle,
});
Ok(())
}
fn drain_pending(&mut self) {
if self.relay_rdma {
return;
}
if let Some(pending) = self.pending_refresh.take() {
let _ = join_pending_refresh(self.session_id.as_str(), pending);
}
}
fn read(&mut self, slot_index: usize, length: usize) -> Result<usize, String> {
self.transport
.rdma_read_pipelined(
&self.local_slots[slot_index].registered,
length,
&self.remote,
RDMA_READ_PIPELINE_CHUNK_BYTES,
RDMA_READ_PIPELINE_DEPTH,
)
.map_err(|err| format!("rdma read {}: {err}", self.relative_path))?;
Ok(length)
}
/// Returns a new shared handle to local read slot `slot_index` (panics if out of range).
fn local_slot(&self, slot_index: usize) -> Arc<LocalReadSlot> {
    let slot = &self.local_slots[slot_index];
    Arc::clone(slot)
}
}
/// Reads the payload for `job` through `reader` and packages it as a [`PendingWrite`].
///
/// - Fails immediately, before any read is issued, if `open_files` has no handle for the
///   job's relative path.
/// - An owned payload is moved into the pending write.
/// - A borrowed payload is converted to a handle on the reader's local slot; it errors if
///   that slot is unavailable.
/// - `read_ms` covers only the `read_payload` call, while `started` marks the start of the
///   whole job and is later used for its total time.
/// - `_target_temp_path` is accepted for signature compatibility and is unused here.
fn prepare_pending_write(
    reader: &mut ReusableRdmaReader,
    client: &Client,
    job: &TransferJob,
    next_job: Option<&TransferJob>,
    allow_prefetch: bool,
    local_slot: usize,
    open_files: &BTreeMap<String, File>,
    timeout_seconds: i32,
    retry_limit: usize,
    _target_temp_path: &str,
) -> Result<PendingWrite, String> {
    if !open_files.contains_key(&job.relative_path) {
        return Err(format!("target file {} not prepared", job.relative_path));
    }
    let started = Instant::now();
    let read_started = Instant::now();
    let (lease, transport) = reader.read_payload(
        client,
        job,
        next_job,
        allow_prefetch,
        local_slot,
        timeout_seconds,
        retry_limit,
    )?;
    let read_ms = read_started.elapsed().as_millis();
    let payload = match lease {
        PayloadLease::Owned(bytes) => PendingPayload::Owned { payload: bytes },
        PayloadLease::Borrowed { slot_index, length } => {
            let Some(slot) = reader.slot_handle(slot_index) else {
                return Err(format!(
                    "rdma slot {} for {} not available",
                    slot_index, job.relative_path
                ));
            };
            PendingPayload::Borrowed { slot, length }
        }
    };
    Ok(PendingWrite {
        job: job.clone(),
        payload,
        transport,
        read_ms,
        started,
    })
}
fn complete_pending_write(
pending: PendingWrite,
target_file: &File,
enable_chunk_crc32c: bool,
target_temp_path: &str,
capture_relay_payload: bool,
) -> Result<WriteCompletion, String> {
let write_started = Instant::now();
let (payload_len, checksum, crc_ms, relay_payload) = match pending.payload {
PendingPayload::Owned { payload } => {
target_file
.write_at(&payload, pending.job.offset as u64)
.map_err(|err| format!("write chunk {}@{}: {err}", pending.job.relative_path, pending.job.offset))?;
let payload_len = payload.len();
let (checksum, crc_ms) = if enable_chunk_crc32c {
let crc_started = Instant::now();
let checksum = crc32c::crc32c(&payload);
(checksum, crc_started.elapsed().as_millis())
} else {
(0, 0)
};
let relay_payload = if capture_relay_payload { Some(payload) } else { None };
(payload_len, checksum, crc_ms, relay_payload)
}
PendingPayload::Borrowed { slot, length } => {
let payload = slot.slice(length);
target_file
.write_at(payload, pending.job.offset as u64)
.map_err(|err| format!("write chunk {}@{}: {err}", pending.job.relative_path, pending.job.offset))?;
let (checksum, crc_ms) = if enable_chunk_crc32c {
let crc_started = Instant::now();
let checksum = crc32c::crc32c(payload);
(checksum, crc_started.elapsed().as_millis())
} else {
(0, 0)
};
let relay_payload = if capture_relay_payload {
Some(payload.to_vec())
} else {
None
};
(payload.len(), checksum, crc_ms, relay_payload)
}
};
let write_ms = write_started.elapsed().as_millis();
let read_ms = pending.read_ms;
let _crc_ms = crc_ms;
let total_ms = pending.started.elapsed().as_millis();
let chunk = TransferredChunk {
chunk_id: format!("{}:{}:{}", pending.job.source_id, pending.job.relative_path, pending.job.offset),
file_path: normalize_path(path_join(target_temp_path, &pending.job.relative_path)),
relative_path: pending.job.relative_path,
offset: pending.job.offset,
size: payload_len as i64,
crc32c: if enable_chunk_crc32c {
encode_crc32c(checksum)
} else {
String::new()
},
source_id: pending.job.source_id,
};
Ok(WriteCompletion {
outcome: JobOutcome {
chunk: chunk.clone(),
transport: pending.transport,
read_ms,
write_ms,
total_ms,
},
relay_chunk: relay_payload.map(|payload| RelayPublishedChunk {
chunk,
transport: pending.transport,
payload,
}),
})
}
/// Runs [`complete_pending_write`] for `pending` on a newly spawned thread.
///
/// The thread takes ownership of `target_file` and `target_temp_path` for the duration
/// of the write; the returned handle yields the write's completion or error.
fn spawn_pending_write(
    pending: PendingWrite,
    target_file: File,
    enable_chunk_crc32c: bool,
    target_temp_path: String,
    capture_relay_payload: bool,
) -> JoinHandle<Result<WriteCompletion, String>> {
    let task = move || {
        let file = target_file;
        let temp_root = target_temp_path;
        complete_pending_write(pending, &file, enable_chunk_crc32c, &temp_root, capture_relay_payload)
    };
    std::thread::spawn(task)
}
fn push_outcome(
completion: WriteCompletion,
outcomes: &mut Vec<JobOutcome>,
relay_batch: &mut Vec<RelayPublishedChunk>,
relay_batch_bytes: &mut i64,
relay_publisher: Option<&RelayBatchPublisher>,
) -> Result<(), String> {
if let Some(relay_chunk) = completion.relay_chunk {
*relay_batch_bytes += relay_chunk.chunk.size;
relay_batch.push(relay_chunk);
let publish_threshold = relay_publisher
.map(|publisher| publisher.batch_bytes())
.unwrap_or(RELAY_PUBLISH_BATCH_BYTES);
if *relay_batch_bytes >= publish_threshold {
if let Some(publisher) = relay_publisher {
let publish_bytes = *relay_batch_bytes;
let publish_chunks = relay_batch.len();
let publish_started = Instant::now();
publisher.publish(std::mem::take(relay_batch))?;
eprintln!(
"relay publish hot-path queued bytes={} chunks={} blocking_ms={}",
publish_bytes,
publish_chunks,
publish_started.elapsed().as_millis()
);
}
relay_batch.clear();
*relay_batch_bytes = 0;
}
}
outcomes.push(completion.outcome);
Ok(())
}
/// Blocks for the next write completion on `completions` and records it via
/// [`push_outcome`].
///
/// # Errors
/// Fails if the channel is disconnected, if the writer reported an error for this chunk,
/// or if relay publishing fails.
fn receive_write_completion(
    completions: &Receiver<Result<WriteCompletion, String>>,
    outcomes: &mut Vec<JobOutcome>,
    relay_batch: &mut Vec<RelayPublishedChunk>,
    relay_batch_bytes: &mut i64,
    relay_publisher: Option<&RelayBatchPublisher>,
) -> Result<(), String> {
    let completion = match completions.recv() {
        Ok(result) => result?,
        Err(_) => return Err("writer completion channel closed".to_string()),
    };
    push_outcome(completion, outcomes, relay_batch, relay_batch_bytes, relay_publisher)
}
/// Returns the process-wide registry of shared write queues, keyed by queue key.
/// It is lazily initialised as an empty map on first use.
fn shared_write_queues() -> &'static Mutex<HashMap<String, Arc<SharedWriteQueue>>> {
    SHARED_WRITE_QUEUES.get_or_init(Default::default)
}
/// Spawns the dedicated writer thread used by the single-writer pipeline.
///
/// The thread processes `PendingWrite`s from `pending_rx` in arrival order and writes each
/// one into its target handle in `open_files`. One result per received write is sent on
/// `completion_tx`, in the same order.
///
/// Per-chunk failures, including a target file missing from `open_files`, are sent as
/// `Err` on `completion_tx` instead of terminating the thread. The consumer therefore sees
/// the real error from `recv()`, rather than only a "channel closed" message while the
/// actual cause sits in an unjoined handle.
///
/// The thread exits with `Ok(())` once `pending_rx` is disconnected. It returns `Err` only
/// if the completion receiver has been dropped.
fn spawn_serial_write_worker(
    open_files: Arc<BTreeMap<String, File>>,
    enable_chunk_crc32c: bool,
    target_temp_path: String,
    capture_relay_payload: bool,
    pending_rx: Receiver<PendingWrite>,
    completion_tx: SyncSender<Result<WriteCompletion, String>>,
) -> JoinHandle<Result<(), String>> {
    std::thread::spawn(move || {
        for pending in pending_rx {
            // Positional writes take `&File`, so the shared handle is used directly;
            // no per-chunk `try_clone` (dup/close syscalls) is needed.
            let completion = match open_files.get(&pending.job.relative_path) {
                Some(target_file) => complete_pending_write(
                    pending,
                    target_file,
                    enable_chunk_crc32c,
                    &target_temp_path,
                    capture_relay_payload,
                ),
                None => Err(format!("target file {} not prepared", pending.job.relative_path)),
            };
            completion_tx
                .send(completion)
                .map_err(|_| "writer completion receiver dropped".to_string())?;
        }
        Ok(())
    })
}
fn spawn_shared_write_worker(
open_files: Arc<BTreeMap<String, File>>,
target_temp_path: String,
pending_rx: Receiver<SharedWriteTask>,
) {
let _ = std::thread::spawn(move || {
while let Ok(task) = pending_rx.recv() {
let writer_file = match open_files.get(&task.pending.job.relative_path) {
Some(file) => match file.try_clone() {
Ok(clone) => clone,
Err(err) => {
let _ = task.completion_tx.send(Err(format!(
"clone target file {}: {err}",
task.pending.job.relative_path
)));
continue;
}
},
None => {
let _ = task.completion_tx.send(Err(format!(
"target file {} not prepared",
task.pending.job.relative_path
)));
continue;
}
};
let completion = complete_pending_write(
task.pending,
&writer_file,
task.enable_chunk_crc32c,
&target_temp_path,
task.capture_relay_payload,
);
let _ = task.completion_tx.send(completion);
}
});
}
/// Returns the shared write queue registered under `target_temp_path`, creating it on
/// first use, and counts the caller as one additional active user.
///
/// On creation:
/// - a bounded channel of capacity `queue_depth` is built;
/// - a worker bound to `open_files` is spawned;
/// - the queue starts with one active user.
///
/// `open_files` is only used when a new queue is created; an existing queue keeps the
/// handles it was created with.
///
/// # Errors
/// Fails if the registry mutex is poisoned.
fn get_or_create_shared_write_queue(
    target_temp_path: &str,
    open_files: Arc<BTreeMap<String, File>>,
    queue_depth: usize,
) -> Result<Arc<SharedWriteQueue>, String> {
    let mut registry = match shared_write_queues().lock() {
        Ok(guard) => guard,
        Err(_) => return Err("lock shared write queues".to_string()),
    };
    match registry.get(target_temp_path) {
        Some(existing) => {
            existing.active_users.fetch_add(1, Ordering::Relaxed);
            Ok(Arc::clone(existing))
        }
        None => {
            let (pending_tx, pending_rx) = mpsc::sync_channel::<SharedWriteTask>(queue_depth);
            let created = Arc::new(SharedWriteQueue {
                pending_tx,
                active_users: AtomicUsize::new(1),
            });
            spawn_shared_write_worker(open_files, target_temp_path.to_string(), pending_rx);
            registry.insert(target_temp_path.to_string(), Arc::clone(&created));
            Ok(created)
        }
    }
}
/// Drops one active user of a shared write queue. When the last user leaves, the queue is
/// unregistered. Once the registry no longer references it, dropping the final `Arc`
/// closes the channel and lets the worker thread exit.
///
/// The decrement is performed while holding the registry lock. `get_or_create_shared_write_queue`
/// increments under the same lock, so a queue can no longer be handed out between "count
/// reached zero" and "removed from the map". Previously that race could leave a queue in
/// use but unregistered, and a later caller would then spawn a second worker for the same
/// key.
///
/// If the registry lock is poisoned, the count is still decremented but the entry is left
/// in place (as before).
fn release_shared_write_queue(target_temp_path: &str, queue: Arc<SharedWriteQueue>) {
    let Ok(mut queues) = shared_write_queues().lock() else {
        queue.active_users.fetch_sub(1, Ordering::AcqRel);
        return;
    };
    if queue.active_users.fetch_sub(1, Ordering::AcqRel) != 1 {
        return;
    }
    // Only remove the entry if it is still this exact queue (not a newer replacement).
    if queues
        .get(target_temp_path)
        .is_some_and(|current| Arc::ptr_eq(current, &queue))
    {
        queues.remove(target_temp_path);
    }
}
/// Executes `group` with one dedicated writer thread (`spawn_serial_write_worker`).
///
/// The calling thread reads each job through `reader`, while the writer thread persists
/// chunks that were already read. Job `index` is sent to the writer before job `index + 1`
/// is read, so reads overlap writes. `queue_depth` is the capacity of both channels and
/// the cap on `inflight` writes: once it is reached, the caller blocks on one completion
/// before reading the next job.
///
/// When `relay_publisher` is set, relay payloads are captured and batched, and read
/// prefetch is disabled (`allow_prefetch`). Any remaining batch is published after all
/// writes complete. The reader is closed synchronously for all-relay groups or when
/// `force_sync_close` is set; otherwise it is closed in the background.
///
/// NOTE(review): slots are assigned as `index % RDMA_LOCAL_SLOT_COUNT`. If `queue_depth`
/// can exceed `RDMA_LOCAL_SLOT_COUNT`, a borrowed slot might be refilled while an earlier
/// queued write still references it. Confirm that `read_payload`/`slot_handle` guard this.
/// NOTE(review): every `?` before the close step returns without closing `reader`. Confirm
/// that `ReusableRdmaReader` releases its RDMA session on drop.
fn execute_job_group_single_writer(
group: Vec<TransferJob>,
client: Client,
open_files: Arc<BTreeMap<String, File>>,
read_timeout: i32,
retry_limit: usize,
enable_chunk_crc32c: bool,
target_temp_path: String,
relay_publisher: Option<RelayBatchPublisher>,
queue_depth: usize,
force_sync_close: bool,
) -> Result<Vec<JobOutcome>, String> {
if group.is_empty() {
return Ok(Vec::new());
}
let mut reader = ReusableRdmaReader::new();
let mut outcomes = Vec::with_capacity(group.len());
let mut relay_batch: Vec<RelayPublishedChunk> = Vec::new();
let mut relay_batch_bytes = 0_i64;
// Relay payload copies are only needed when something will publish them.
let capture_relay_payload = relay_publisher.is_some();
let allow_prefetch = relay_publisher.is_none();
let (pending_tx, pending_rx) = mpsc::sync_channel::<PendingWrite>(queue_depth);
let (completion_tx, completion_rx) = mpsc::sync_channel::<Result<WriteCompletion, String>>(queue_depth);
let writer = spawn_serial_write_worker(
Arc::clone(&open_files),
enable_chunk_crc32c,
target_temp_path.clone(),
capture_relay_payload,
pending_rx,
completion_tx,
);
let mut inflight = 0_usize;
// Prime the pipeline with the first read so the loop can overlap read(i) with write(i-1).
let mut pending = prepare_pending_write(
&mut reader,
&client,
&group[0],
group.get(1),
allow_prefetch,
0,
open_files.as_ref(),
read_timeout,
retry_limit,
&target_temp_path,
)?;
for index in 1..group.len() {
pending_tx
.send(pending)
.map_err(|_| "serial writer queue closed".to_string())?;
inflight += 1;
// Back-pressure: never let more than `queue_depth` writes be outstanding.
if inflight >= queue_depth {
receive_write_completion(
&completion_rx,
&mut outcomes,
&mut relay_batch,
&mut relay_batch_bytes,
relay_publisher.as_ref(),
)?;
inflight -= 1;
}
pending = prepare_pending_write(
&mut reader,
&client,
&group[index],
group.get(index + 1),
allow_prefetch,
index % RDMA_LOCAL_SLOT_COUNT,
open_files.as_ref(),
read_timeout,
retry_limit,
&target_temp_path,
)?;
}
pending_tx
.send(pending)
.map_err(|_| "serial writer queue closed".to_string())?;
inflight += 1;
// Closing the sender lets the writer loop end after it drains the queue.
drop(pending_tx);
while inflight > 0 {
receive_write_completion(
&completion_rx,
&mut outcomes,
&mut relay_batch,
&mut relay_batch_bytes,
relay_publisher.as_ref(),
)?;
inflight -= 1;
}
writer
.join()
.map_err(|_| "serial writer worker panicked".to_string())??;
// Flush whatever relay batch did not reach the hot-path threshold.
if let Some(publisher) = relay_publisher.as_ref() {
if !relay_batch.is_empty() {
let publish_bytes = relay_batch.iter().map(|chunk| chunk.chunk.size).sum::<i64>();
let publish_chunks = relay_batch.len();
let publish_started = Instant::now();
publisher.publish(std::mem::take(&mut relay_batch))?;
eprintln!(
"relay publish group-final queued bytes={} chunks={} blocking_ms={}",
publish_bytes,
publish_chunks,
publish_started.elapsed().as_millis()
);
}
}
// Per-group summary for diagnostics; the *_ms values are sums over all jobs.
let bytes = outcomes.iter().map(|outcome| outcome.chunk.size).sum::<i64>();
let read_ms = outcomes.iter().map(|outcome| outcome.read_ms).sum::<u128>();
let write_ms = outcomes.iter().map(|outcome| outcome.write_ms).sum::<u128>();
let total_ms = outcomes.iter().map(|outcome| outcome.total_ms).sum::<u128>();
let source_id = outcomes
.first()
.map(|outcome| outcome.chunk.source_id.as_str())
.unwrap_or("unknown");
let transport = outcomes
.iter()
.all(|outcome| outcome.transport == PayloadTransport::Rdma);
eprintln!(
"direct job-group serial-writer source={} jobs={} bytes={} read_ms={} write_ms={} total_ms={} all_rdma={} queue_depth={}",
source_id,
outcomes.len(),
bytes,
read_ms,
write_ms,
total_ms,
transport,
queue_depth
);
if group.iter().all(|job| job.source.relay_rdma.is_some()) || force_sync_close {
eprintln!("direct job-group serial-writer closing reader sync source={}", source_id);
reader.close(&client);
eprintln!("direct job-group serial-writer reader sync-closed source={}", source_id);
} else {
eprintln!("direct job-group serial-writer closing reader background source={}", source_id);
reader.close_background(client.clone());
eprintln!("direct job-group serial-writer reader background-started source={}", source_id);
}
Ok(outcomes)
}
/// Executes `group` through the shared writer registered under `shared_queue_key`.
/// The writer is obtained with `get_or_create_shared_write_queue` and released afterwards
/// with `release_shared_write_queue`, so groups using the same key share one worker thread.
///
/// Each submitted chunk carries its own capacity-1 completion channel. Receivers are kept
/// in a FIFO, so this group consumes completions in submission order. Once `queue_depth`
/// receivers are outstanding, the caller waits on the oldest before reading the next job.
///
/// When `relay_publisher` is set, relay payloads are captured and read prefetch is
/// disabled. The body runs in an immediately-invoked closure so that the shared queue is
/// released on both the success and the error path.
///
/// NOTE(review): slots are assigned as `index % RDMA_LOCAL_SLOT_COUNT`. If `queue_depth`
/// can exceed `RDMA_LOCAL_SLOT_COUNT`, a borrowed slot might be refilled while a queued
/// write still references it. Confirm that `read_payload`/`slot_handle` prevent this.
/// NOTE(review): an early `?` inside the closure skips `reader.close*`. Confirm that
/// `ReusableRdmaReader` releases its session on drop.
fn execute_job_group_shared_writer(
group: Vec<TransferJob>,
client: Client,
open_files: Arc<BTreeMap<String, File>>,
read_timeout: i32,
retry_limit: usize,
enable_chunk_crc32c: bool,
target_temp_path: String,
shared_queue_key: String,
relay_publisher: Option<RelayBatchPublisher>,
queue_depth: usize,
force_sync_close: bool,
) -> Result<Vec<JobOutcome>, String> {
if group.is_empty() {
return Ok(Vec::new());
}
let mut reader = ReusableRdmaReader::new();
let mut outcomes = Vec::with_capacity(group.len());
let mut relay_batch: Vec<RelayPublishedChunk> = Vec::new();
let mut relay_batch_bytes = 0_i64;
let capture_relay_payload = relay_publisher.is_some();
let allow_prefetch = relay_publisher.is_none();
let shared_queue =
get_or_create_shared_write_queue(&shared_queue_key, Arc::clone(&open_files), queue_depth)?;
// Run the body in a closure so the shared queue is released below even on error.
let result = (|| -> Result<Vec<JobOutcome>, String> {
// One receiver per submitted chunk, oldest first.
let mut completion_receivers: VecDeque<Receiver<Result<WriteCompletion, String>>> = VecDeque::new();
let mut pending = prepare_pending_write(
&mut reader,
&client,
&group[0],
group.get(1),
allow_prefetch,
0,
open_files.as_ref(),
read_timeout,
retry_limit,
&target_temp_path,
)?;
for index in 1..group.len() {
let (completion_tx, completion_rx) = mpsc::sync_channel::<Result<WriteCompletion, String>>(1);
shared_queue
.pending_tx
.send(SharedWriteTask {
pending,
enable_chunk_crc32c,
capture_relay_payload,
completion_tx,
})
.map_err(|_| "shared writer queue closed".to_string())?;
completion_receivers.push_back(completion_rx);
// Back-pressure: wait on the oldest write once `queue_depth` are outstanding.
if completion_receivers.len() >= queue_depth {
let receiver = completion_receivers
.pop_front()
.ok_or_else(|| "shared writer receiver queue empty".to_string())?;
receive_write_completion(
&receiver,
&mut outcomes,
&mut relay_batch,
&mut relay_batch_bytes,
relay_publisher.as_ref(),
)?;
}
pending = prepare_pending_write(
&mut reader,
&client,
&group[index],
group.get(index + 1),
allow_prefetch,
index % RDMA_LOCAL_SLOT_COUNT,
open_files.as_ref(),
read_timeout,
retry_limit,
&target_temp_path,
)?;
}
let (completion_tx, completion_rx) = mpsc::sync_channel::<Result<WriteCompletion, String>>(1);
shared_queue
.pending_tx
.send(SharedWriteTask {
pending,
enable_chunk_crc32c,
capture_relay_payload,
completion_tx,
})
.map_err(|_| "shared writer queue closed".to_string())?;
completion_receivers.push_back(completion_rx);
// Drain every outstanding write of this group in submission order.
while let Some(receiver) = completion_receivers.pop_front() {
receive_write_completion(
&receiver,
&mut outcomes,
&mut relay_batch,
&mut relay_batch_bytes,
relay_publisher.as_ref(),
)?;
}
// Flush whatever relay batch did not reach the hot-path threshold.
if let Some(publisher) = relay_publisher.as_ref() {
if !relay_batch.is_empty() {
let publish_bytes = relay_batch.iter().map(|chunk| chunk.chunk.size).sum::<i64>();
let publish_chunks = relay_batch.len();
let publish_started = Instant::now();
publisher.publish(std::mem::take(&mut relay_batch))?;
eprintln!(
"relay publish group-final queued bytes={} chunks={} blocking_ms={}",
publish_bytes,
publish_chunks,
publish_started.elapsed().as_millis()
);
}
}
// Per-group summary for diagnostics; the *_ms values are sums over all jobs.
let bytes = outcomes.iter().map(|outcome| outcome.chunk.size).sum::<i64>();
let read_ms = outcomes.iter().map(|outcome| outcome.read_ms).sum::<u128>();
let write_ms = outcomes.iter().map(|outcome| outcome.write_ms).sum::<u128>();
let total_ms = outcomes.iter().map(|outcome| outcome.total_ms).sum::<u128>();
let source_id = outcomes
.first()
.map(|outcome| outcome.chunk.source_id.as_str())
.unwrap_or("unknown");
let transport = outcomes
.iter()
.all(|outcome| outcome.transport == PayloadTransport::Rdma);
eprintln!(
"direct job-group shared-writer source={} jobs={} bytes={} read_ms={} write_ms={} total_ms={} all_rdma={} queue_depth={}",
source_id,
outcomes.len(),
bytes,
read_ms,
write_ms,
total_ms,
transport,
queue_depth
);
if group.iter().all(|job| job.source.relay_rdma.is_some()) || force_sync_close {
eprintln!("direct job-group shared-writer closing reader sync source={}", source_id);
reader.close(&client);
eprintln!("direct job-group shared-writer reader sync-closed source={}", source_id);
} else {
eprintln!("direct job-group shared-writer closing reader background source={}", source_id);
reader.close_background(client.clone());
eprintln!("direct job-group shared-writer reader background-started source={}", source_id);
}
Ok(outcomes)
})();
release_shared_write_queue(&shared_queue_key, shared_queue);
result
}
/// Executes one group of transfer jobs and returns one `JobOutcome` per written chunk.
///
/// Strategy selection, as coded below:
/// - Shared writer: used when the group is not relay-RDMA, is not relay-publishing, is
///   smaller than `LARGE_SINGLE_SOURCE_GROUP_BYTES`, and `rdma_shared_writer_queue_depth()`
///   is non-zero (`execute_job_group_shared_writer`).
/// - Dedicated serial writer: used otherwise when the computed single-writer depth is
///   non-zero (`execute_job_group_single_writer`). Large single-source groups always take
///   this path, with a depth of at least 2.
/// - Fallback: one writer thread per local RDMA slot. A slot's previous writer is joined
///   before the slot is reused for the next read.
///
/// In the fallback path, the remaining writers are joined in slot-index order at the end,
/// so `outcomes` is not guaranteed to follow job order.
///
/// NOTE(review): an early `?` in the fallback path returns without closing `reader` or
/// joining the spawned writers. Confirm that dropping them is acceptable.
fn execute_job_group(
group: Vec<TransferJob>,
client: Client,
open_files: Arc<BTreeMap<String, File>>,
read_timeout: i32,
retry_limit: usize,
enable_chunk_crc32c: bool,
target_temp_path: String,
shared_queue_key: String,
relay_publisher: Option<RelayBatchPublisher>,
force_sync_close: bool,
) -> Result<Vec<JobOutcome>, String> {
if group.is_empty() {
return Ok(Vec::new());
}
let relay_rdma_group = group_is_relay_rdma(&group);
let relay_publish_group = relay_publisher.is_some();
let large_single_source_group =
!relay_rdma_group && !relay_publish_group && group_total_bytes(&group) >= LARGE_SINGLE_SOURCE_GROUP_BYTES;
// A depth of 0 disables the shared writer for this group.
let shared_writer_queue_depth = if relay_rdma_group || relay_publish_group || large_single_source_group {
0
} else {
rdma_shared_writer_queue_depth()
};
if shared_writer_queue_depth > 0 {
return execute_job_group_shared_writer(
group,
client,
open_files,
read_timeout,
retry_limit,
enable_chunk_crc32c,
target_temp_path,
shared_queue_key,
relay_publisher,
shared_writer_queue_depth,
force_sync_close,
);
}
// Large single-source groups always get a serial writer with depth >= 2.
let single_writer_queue_depth = if large_single_source_group {
rdma_shared_writer_queue_depth()
.max(rdma_single_writer_queue_depth())
.max(2)
} else if relay_rdma_group || relay_publish_group {
rdma_shared_writer_queue_depth().max(rdma_single_writer_queue_depth())
} else {
rdma_single_writer_queue_depth()
};
if single_writer_queue_depth > 0 {
return execute_job_group_single_writer(
group,
client,
open_files,
read_timeout,
retry_limit,
enable_chunk_crc32c,
target_temp_path,
relay_publisher,
single_writer_queue_depth,
force_sync_close,
);
}
// Fallback: one writer thread per local slot.
let mut reader = ReusableRdmaReader::new();
let mut outcomes = Vec::with_capacity(group.len());
let mut relay_batch: Vec<RelayPublishedChunk> = Vec::new();
let mut relay_batch_bytes = 0_i64;
let mut writers: Vec<Option<JoinHandle<Result<WriteCompletion, String>>>> =
(0..RDMA_LOCAL_SLOT_COUNT).map(|_| None).collect();
let capture_relay_payload = relay_publisher.is_some();
let allow_prefetch = relay_publisher.is_none();
let mut pending_slot = 0_usize;
let mut pending = prepare_pending_write(
&mut reader,
&client,
&group[0],
group.get(1),
allow_prefetch,
pending_slot,
open_files.as_ref(),
read_timeout,
retry_limit,
&target_temp_path,
)?;
for index in 1..group.len() {
let writer_slot = pending_slot;
let writer_file = open_files
.get(&pending.job.relative_path)
.ok_or_else(|| format!("target file {} not prepared", pending.job.relative_path))?
.try_clone()
.map_err(|err| format!("clone target file {}: {err}", pending.job.relative_path))?;
writers[writer_slot] = Some(spawn_pending_write(
pending,
writer_file,
enable_chunk_crc32c,
target_temp_path.clone(),
capture_relay_payload,
));
let next_slot = index % RDMA_LOCAL_SLOT_COUNT;
// The next read reuses `next_slot`, so its previous writer must finish first.
if let Some(writer) = writers[next_slot].take() {
let writer_result = writer
.join()
.map_err(|_| format!("write worker panicked for slot {}", next_slot))?;
push_outcome(
writer_result?,
&mut outcomes,
&mut relay_batch,
&mut relay_batch_bytes,
relay_publisher.as_ref(),
)?;
}
let next_pending = prepare_pending_write(
&mut reader,
&client,
&group[index],
group.get(index + 1),
allow_prefetch,
next_slot,
open_files.as_ref(),
read_timeout,
retry_limit,
&target_temp_path,
);
let next_pending = next_pending?;
pending = next_pending;
pending_slot = next_slot;
}
let final_file = open_files
.get(&pending.job.relative_path)
.ok_or_else(|| format!("target file {} not prepared", pending.job.relative_path))?
.try_clone()
.map_err(|err| format!("clone target file {}: {err}", pending.job.relative_path))?;
writers[pending_slot] = Some(spawn_pending_write(
pending,
final_file,
enable_chunk_crc32c,
target_temp_path.clone(),
capture_relay_payload,
));
// Join the remaining writers in slot-index order.
for (slot_index, writer) in writers.into_iter().enumerate() {
if let Some(writer) = writer {
let writer_result = writer
.join()
.map_err(|_| format!("write worker panicked for slot {}", slot_index))?;
push_outcome(
writer_result?,
&mut outcomes,
&mut relay_batch,
&mut relay_batch_bytes,
relay_publisher.as_ref(),
)?;
}
}
// Flush whatever relay batch did not reach the hot-path threshold.
if let Some(publisher) = relay_publisher.as_ref() {
if !relay_batch.is_empty() {
let publish_bytes = relay_batch.iter().map(|chunk| chunk.chunk.size).sum::<i64>();
let publish_chunks = relay_batch.len();
let publish_started = Instant::now();
publisher.publish(std::mem::take(&mut relay_batch))?;
eprintln!(
"relay publish group-final queued bytes={} chunks={} blocking_ms={}",
publish_bytes,
publish_chunks,
publish_started.elapsed().as_millis()
);
}
}
// Per-group summary for diagnostics; the *_ms values are sums over all jobs.
let bytes = outcomes.iter().map(|outcome| outcome.chunk.size).sum::<i64>();
let read_ms = outcomes.iter().map(|outcome| outcome.read_ms).sum::<u128>();
let write_ms = outcomes.iter().map(|outcome| outcome.write_ms).sum::<u128>();
let total_ms = outcomes.iter().map(|outcome| outcome.total_ms).sum::<u128>();
let source_id = outcomes
.first()
.map(|outcome| outcome.chunk.source_id.as_str())
.unwrap_or("unknown");
let transport = outcomes
.iter()
.all(|outcome| outcome.transport == PayloadTransport::Rdma);
eprintln!(
"direct job-group source={} jobs={} bytes={} read_ms={} write_ms={} total_ms={} all_rdma={}",
source_id,
outcomes.len(),
bytes,
read_ms,
write_ms,
total_ms,
transport
);
if group.iter().all(|job| job.source.relay_rdma.is_some()) || force_sync_close {
reader.close(&client);
} else {
reader.close_background(client.clone());
}
Ok(outcomes)
}
/// Chooses how many worker groups to run for a transfer.
///
/// - No jobs: 1.
/// - Ring relay collectives and the default case: the requested parallelism (8 when
///   `spec.parallelism <= 0`), clamped to `1..=min(job_count, 8)`.
/// - Relay root publishers: at most 2.
/// - `DIRECT_STRIPED`: the striped single-source budget when there is one source;
///   otherwise one worker per source, capped by `job_count` and 8.
fn desired_worker_count(spec: &TransferSpec, job_count: usize, source_count: usize) -> usize {
    const DEFAULT_PARALLELISM: usize = 8;
    const MAX_WORKERS: usize = 8;
    if job_count == 0 {
        return 1;
    }
    let clamp = |wanted: usize| wanted.max(1).min(job_count).min(MAX_WORKERS);
    let requested = match spec.parallelism {
        value if value > 0 => value as usize,
        _ => DEFAULT_PARALLELISM,
    };
    if is_ring_relay_collective(spec) {
        return clamp(requested);
    }
    if resolve_relay_root_publisher(spec).is_some() {
        return job_count.min(2);
    }
    if spec.transfer_mode == "DIRECT_STRIPED" {
        return if source_count <= 1 {
            striped_single_source_worker_budget(spec).min(job_count)
        } else {
            clamp(source_count)
        };
    }
    clamp(requested)
}
/// True when the group is non-empty and every job's source carries relay RDMA metadata.
fn group_is_relay_rdma(group: &[TransferJob]) -> bool {
    match group {
        [] => false,
        jobs => jobs.iter().all(|job| job.source.relay_rdma.is_some()),
    }
}
/// Sum of `length` over all jobs in the group (0 for an empty group).
fn group_total_bytes(group: &[TransferJob]) -> i64 {
    group.iter().fold(0_i64, |total, job| total + job.length)
}
/// Splits `jobs` into per-worker groups.
///
/// Multi-source `DIRECT_STRIPED` transfers are grouped by source. Every other case uses
/// contiguous, evenly sized slices.
fn build_job_groups(
    transfer_mode: &str,
    jobs: Vec<TransferJob>,
    worker_count: usize,
    source_count: usize,
) -> Vec<Vec<TransferJob>> {
    let striped_across_sources = source_count > 1 && transfer_mode == "DIRECT_STRIPED";
    if !striped_across_sources {
        return split_jobs_contiguous(jobs, worker_count);
    }
    split_jobs_by_source(jobs, worker_count)
}
/// Splits `jobs` into consecutive groups of `ceil(len / worker_count)` jobs; the last
/// group may be shorter.
///
/// Returns a single group (possibly empty) when there are no jobs or at most one worker.
/// The result can contain fewer than `worker_count` groups.
fn split_jobs_contiguous(jobs: Vec<TransferJob>, worker_count: usize) -> Vec<Vec<TransferJob>> {
    if jobs.is_empty() || worker_count <= 1 {
        return vec![jobs];
    }
    let group_len = jobs.len().div_ceil(worker_count);
    let mut remaining = jobs.into_iter();
    std::iter::from_fn(|| {
        let group: Vec<TransferJob> = remaining.by_ref().take(group_len).collect();
        (!group.is_empty()).then_some(group)
    })
    .collect()
}
/// Groups consecutive jobs that share a source (same `source_id`, endpoint, and path).
///
/// If that produces more groups than `worker_count`, the jobs are instead re-split into
/// contiguous groups with [`split_jobs_contiguous`]. A single group is returned when there
/// are no jobs or at most one worker.
///
/// The original line `match ¤t_key` was a mis-encoded `&current_key` and did not compile.
/// Comparing against the last job of the current group also avoids cloning three
/// `String`s per job just to build a comparison key.
fn split_jobs_by_source(jobs: Vec<TransferJob>, worker_count: usize) -> Vec<Vec<TransferJob>> {
    if jobs.is_empty() || worker_count <= 1 {
        return vec![jobs];
    }
    let mut groups: Vec<Vec<TransferJob>> = Vec::new();
    let mut current: Vec<TransferJob> = Vec::new();
    for job in jobs {
        let source_changed = current.last().is_some_and(|previous: &TransferJob| {
            previous.source_id != job.source_id
                || previous.source.endpoint != job.source.endpoint
                || previous.source.path != job.source.path
        });
        if source_changed {
            groups.push(std::mem::take(&mut current));
        }
        current.push(job);
    }
    if !current.is_empty() {
        groups.push(current);
    }
    if groups.len() > worker_count {
        return split_jobs_contiguous(groups.into_iter().flatten().collect(), worker_count);
    }
    groups
}
/// True when both jobs read the same source range into the same destination range:
/// identical source id, endpoint, path, relative path, destination and source offsets,
/// and length.
fn same_job(left: &TransferJob, right: &TransferJob) -> bool {
    let same_origin = left.source_id == right.source_id
        && left.source.endpoint == right.source.endpoint
        && left.source.path == right.source.path;
    let same_range = left.relative_path == right.relative_path
        && left.offset == right.offset
        && left.source_offset == right.source_offset
        && left.length == right.length;
    same_origin && same_range
}
/// Appends jobs that cover every file in `files` end to end, in chunks of
/// `chunk_size_bytes`.
///
/// Empty or negative-size files are skipped. When `chunk_size_bytes` is non-positive or
/// larger than the file, the whole file becomes one job. Source offsets equal destination
/// offsets.
fn append_full_file_jobs(
    target: &mut Vec<TransferJob>,
    segment: &SourceSegmentPlan,
    files: &[ArtifactFile],
    chunk_size_bytes: i64,
) {
    for file in files.iter().filter(|file| file.size_bytes > 0) {
        let total = file.size_bytes;
        let step = if chunk_size_bytes > 0 && chunk_size_bytes <= total {
            chunk_size_bytes
        } else {
            total
        };
        let mut start = 0_i64;
        while start < total {
            let length = step.min(total - start);
            target.push(TransferJob {
                source_id: segment.source_id.clone(),
                source: segment.source_endpoint.clone(),
                relative_path: file.relative_path.clone(),
                offset: start,
                length,
                source_offset: start,
            });
            start += length;
        }
    }
}
/// Appends one job per non-empty byte range of `segment` (ranges with `end <= start` are
/// skipped).
///
/// For relay-RDMA sources the source offset is the range's `relay_offset`; otherwise it
/// equals the destination start.
fn append_range_jobs(target: &mut Vec<TransferJob>, segment: &SourceSegmentPlan) {
    let via_relay = segment.source_endpoint.relay_rdma.is_some();
    target.extend(
        segment
            .byte_ranges
            .iter()
            .filter(|range| range.end > range.start)
            .map(|range| TransferJob {
                source_id: segment.source_id.clone(),
                source: segment.source_endpoint.clone(),
                relative_path: range.relative_path.clone(),
                offset: range.start,
                length: range.end - range.start,
                source_offset: if via_relay { range.relay_offset } else { range.start },
            }),
    );
}
/// Merges each job into the previous one when all of the following hold:
/// - both target the same file from the same source;
/// - they are contiguous in both destination and source offsets;
/// - the merged length stays within `max_bytes`.
///
/// Returns `jobs` unchanged when there is at most one job or `max_bytes <= 0`.
fn coalesce_jobs(jobs: Vec<TransferJob>, max_bytes: i64) -> Vec<TransferJob> {
    if max_bytes <= 0 || jobs.len() <= 1 {
        return jobs;
    }
    let mut merged: Vec<TransferJob> = Vec::with_capacity(jobs.len());
    for job in jobs {
        let extends_previous = merged.last().is_some_and(|prev| {
            let same_source = prev.source_id == job.source_id
                && prev.relative_path == job.relative_path
                && prev.source.endpoint == job.source.endpoint
                && prev.source.path == job.source.path;
            let contiguous = prev.offset + prev.length == job.offset
                && prev.source_offset + prev.length == job.source_offset;
            let within_budget = prev.length + job.length <= max_bytes;
            same_source && contiguous && within_budget
        });
        if extends_previous {
            if let Some(prev) = merged.last_mut() {
                prev.length += job.length;
            }
        } else {
            merged.push(job);
        }
    }
    merged
}
fn prepare_target_files(root: &str, files: &[ArtifactFile], preserve_existing: bool) -> Result<BTreeMap<String, File>, String> {
let root_path = PathBuf::from(root);
if root_path.exists() && !preserve_existing {
fs::remove_dir_all(&root_path)
.map_err(|err| format!("reset target temp path {}: {err}", root))?;
}
fs::create_dir_all(&root_path).map_err(|err| format!("create target temp root {}: {err}", root))?;
let mut open_files = BTreeMap::new();
for file in files {
let target_path = root_path.join(Path::new(&file.relative_path));
if let Some(parent) = target_path.parent() {
fs::create_dir_all(parent)
.map_err(|err| format!("create target directory {}: {err}", parent.display()))?;
}
let handle = OpenOptions::new()
.create(true)
.read(true)
.write(true)
.open(&target_path)
.map_err(|err| format!("open target file {}: {err}", target_path.display()))?;
let current_len = handle
.metadata()
.map_err(|err| format!("stat target file {}: {err}", target_path.display()))?
.len();
if !preserve_existing || current_len != file.size_bytes as u64 {
handle
.set_len(file.size_bytes as u64)
.map_err(|err| format!("resize target file {}: {err}", target_path.display()))?;
}
open_files.insert(file.relative_path.clone(), handle);
}
Ok(open_files)
}
/// Returns the size in bytes of `relative_path` within the segment's source.
///
/// Sources with an empty endpoint are local: the file is stat'ed directly. Remote sources
/// are queried via `POST {endpoint}/v1/exports/stat`, using a fresh client with a
/// 30-second timeout.
fn stat_source(segment: &SourceSegmentPlan, relative_path: &str) -> Result<i64, String> {
    let source = &segment.source_endpoint;
    if source.endpoint.is_empty() {
        let local_path = resolve_source_path(&source.path, relative_path);
        return match fs::metadata(&local_path) {
            Ok(meta) => Ok(meta.len() as i64),
            Err(err) => Err(format!("stat local source {}: {err}", local_path.display())),
        };
    }
    let http = Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .map_err(|err| format!("build HTTP client for stat: {err}"))?;
    let url = format!("{}/v1/exports/stat", source.endpoint);
    let request = StatExportRequest {
        root_path: &source.path,
        relative_path,
    };
    let response = http
        .post(url)
        .json(&request)
        .send()
        .map_err(|err| format!("request stat export: {err}"))?;
    let status = response.status();
    if !status.is_success() {
        return Err(format!("stat export failed with status {}", status));
    }
    response
        .json::<StatExportResponse>()
        .map(|body| body.size_bytes)
        .map_err(|err| format!("decode stat export response: {err}"))
}
/// Fetches the bytes for `job` over HTTP and returns them with the source's expected
/// CRC32C, if one was reported.
///
/// The raw endpoint (`/v1/exports/readat/raw`) is tried first; its CRC comes from the
/// `X-Chunk-CRC32C` header. If that endpoint answers with a non-success status, the
/// request falls back to the JSON endpoint (`/v1/exports/readat`), whose payload is
/// base64-encoded and whose CRC is in the body.
///
/// Both paths reject a payload whose length differs from `job.length`. The request uses
/// `job.offset`, not `job.source_offset`.
fn read_payload_via_http(client: &Client, job: &TransferJob) -> Result<(Vec<u8>, Option<String>), String> {
    let request = ReadChunkRequest {
        root_path: &job.source.path,
        relative_path: &job.relative_path,
        offset: job.offset,
        length: job.length,
    };
    let expected_len = job.length as usize;
    let raw_response = client
        .post(format!("{}/v1/exports/readat/raw", job.source.endpoint))
        .json(&request)
        .send()
        .map_err(|err| format!("request export raw chunk {}@{}: {err}", job.relative_path, job.offset))?;
    if raw_response.status().is_success() {
        let expected_crc32c = raw_response
            .headers()
            .get("X-Chunk-CRC32C")
            .and_then(|value| value.to_str().ok())
            .map(str::trim)
            .filter(|text| !text.is_empty())
            .map(str::to_string);
        let payload = raw_response
            .bytes()
            .map_err(|err| format!("decode raw chunk response {}@{}: {err}", job.relative_path, job.offset))?
            .to_vec();
        if payload.len() != expected_len {
            return Err(format!(
                "read export raw short-read for {}@{}: expected {} bytes, got {}",
                job.relative_path,
                job.offset,
                job.length,
                payload.len()
            ));
        }
        return Ok((payload, expected_crc32c));
    }
    // The raw endpoint did not succeed: retry the same request against the JSON endpoint.
    let response = client
        .post(format!("{}/v1/exports/readat", job.source.endpoint))
        .json(&request)
        .send()
        .map_err(|err| format!("request export chunk {}@{}: {err}", job.relative_path, job.offset))?;
    let status = response.status();
    if !status.is_success() {
        return Err(format!(
            "read export failed for {}@{} with status {}",
            job.relative_path, job.offset, status
        ));
    }
    let body: ReadChunkResponse = response
        .json()
        .map_err(|err| format!("decode chunk response {}@{}: {err}", job.relative_path, job.offset))?;
    let payload = BASE64_STANDARD
        .decode(body.data.as_bytes())
        .map_err(|err| format!("decode chunk payload {}@{}: {err}", job.relative_path, job.offset))?;
    if payload.len() != expected_len {
        return Err(format!(
            "read export short-read for {}@{}: expected {} bytes, got {}",
            job.relative_path,
            job.offset,
            job.length,
            payload.len()
        ));
    }
    let expected_crc32c = (!body.crc32c.is_empty()).then_some(body.crc32c);
    Ok((payload, expected_crc32c))
}
/// Asks the export service at `endpoint` to close RDMA export session `session_id`.
///
/// # Errors
/// Fails if the request cannot be sent, the status is not a success, the body cannot be
/// decoded, or the service reports `closed == false`.
fn close_rdma_session(client: &Client, endpoint: &str, session_id: &str) -> Result<(), String> {
    let request = CloseRDMAExportRequest {
        session_id: session_id.to_string(),
    };
    let response = client
        .post(format!("{}/v1/exports/rdma/close", endpoint))
        .json(&request)
        .send()
        .map_err(|err| format!("close rdma export session {}: {err}", session_id))?;
    let status = response.status();
    if !status.is_success() {
        return Err(format!(
            "close rdma export session {} failed with status {}",
            session_id, status
        ));
    }
    let body = response
        .json::<CloseRDMAExportResponseOwned>()
        .map_err(|err| format!("decode close rdma export session {} response: {err}", session_id))?;
    if body.closed {
        Ok(())
    } else {
        Err(format!("close rdma export session {} did not release source session", session_id))
    }
}
/// Reads `length` bytes at `offset` from `relative_path` under `root_path`.
///
/// The path is validated with `resolve_source_path_checked`, so it cannot escape the root.
///
/// # Errors
/// - The path is invalid or missing.
/// - `length` is negative. Previously, `length as usize` wrapped to an enormous allocation
///   and aborted the process before any range check ran.
/// - The range lies outside the file, or the read fails.
fn read_local_chunk(root_path: &str, relative_path: &str, offset: i64, length: i64) -> Result<Vec<u8>, String> {
    let source_path = resolve_source_path_checked(root_path, relative_path)?;
    let buffer_len = usize::try_from(length)
        .map_err(|_| format!("invalid chunk length {} for {}", length, source_path.display()))?;
    let mut buffer = vec![0_u8; buffer_len];
    load_source_range(&source_path, offset, length, &mut buffer)?;
    Ok(buffer)
}
/// Fills `buffer` from `source_path`, starting at `offset`, after checking that the range
/// `[offset, offset + length)` lies within the file.
///
/// The bounds check uses checked `u64` arithmetic. Negative offsets or lengths and
/// `offset + length` overflow are rejected. Previously, `offset as u64 + length as u64`
/// could panic in debug builds, or wrap in release builds and pass the check for a
/// negative `length`.
///
/// # Errors
/// Fails if the file cannot be stat'ed or opened, the range is out of bounds, or the read
/// fails.
fn load_source_range(source_path: &Path, offset: i64, length: i64, buffer: &mut [u8]) -> Result<(), String> {
    let metadata = fs::metadata(source_path)
        .map_err(|err| format!("stat source file {}: {err}", source_path.display()))?;
    let file_len = metadata.len();
    let range_end = u64::try_from(offset)
        .ok()
        .zip(u64::try_from(length).ok())
        .and_then(|(start, len)| start.checked_add(len));
    if !range_end.is_some_and(|end| end <= file_len) {
        return Err(format!(
            "requested range [{} , {}) exceeds file size {}",
            offset,
            offset.saturating_add(length),
            file_len
        ));
    }
    let file = File::open(source_path).map_err(|err| format!("open source file {}: {err}", source_path.display()))?;
    load_source_range_from_file(&file, source_path, offset, length, buffer)
}
fn load_source_range_from_file(
file: &File,
source_path: &Path,
offset: i64,
length: i64,
buffer: &mut [u8],
) -> Result<(), String> {
let mut filled = 0_usize;
while filled < buffer.len() {
let read = file
.read_at(&mut buffer[filled..], offset as u64 + filled as u64)
.map_err(|err| format!("read source file {}@{}: {err}", source_path.display(), offset + filled as i64))?;
if read == 0 {
return Err(format!(
"read source file {}@{} reached EOF before {} bytes",
source_path.display(),
offset,
length
));
}
filled += read;
}
Ok(())
}
/// Joins `relative_path` onto `root_path` without validation.
///
/// An empty `relative_path` returns the root exactly as given instead of joining `""` onto it.
fn resolve_source_path(root_path: &str, relative_path: &str) -> PathBuf {
    let root = PathBuf::from(root_path);
    match relative_path {
        "" => root,
        rel => root.join(rel),
    }
}
/// Resolves `relative_path` under `root_path`, rejecting paths that could lexically escape the
/// root, and confirms that the result exists.
///
/// Only normal and `.` components are accepted. Absolute paths and `..`, root, or prefix
/// components are refused. Symlinks are not resolved, so a link inside the root may point
/// elsewhere. `fs::metadata` (which follows links) only has to succeed.
///
/// # Errors
/// Returns an error for a rejected path or when `fs::metadata` on the joined path fails.
fn resolve_source_path_checked(root_path: &str, relative_path: &str) -> Result<PathBuf, String> {
    let root = PathBuf::from(root_path);
    let rel = Path::new(relative_path);
    if rel.is_absolute() {
        return Err(format!("relative path {} must not be absolute", rel.display()));
    }
    let escapes_root = rel
        .components()
        .any(|part| matches!(part, Component::ParentDir | Component::RootDir | Component::Prefix(_)));
    if escapes_root {
        return Err(format!(
            "relative path {} escapes root {}",
            rel.display(),
            root.display()
        ));
    }
    let candidate = resolve_source_path(root_path, relative_path);
    // Defensive re-check on the joined result.
    if !candidate.starts_with(&root) {
        return Err(format!(
            "source path {} escapes root {}",
            candidate.display(),
            root.display()
        ));
    }
    match fs::metadata(&candidate) {
        Ok(_) => Ok(candidate),
        Err(err) => Err(format!("stat source path {}: {err}", candidate.display())),
    }
}
/// Appends `relative_path` to `root` with `PathBuf::push` semantics.
///
/// An absolute `relative_path` replaces `root` entirely.
fn path_join(root: &str, relative_path: &str) -> PathBuf {
    let mut joined = PathBuf::from(root);
    joined.push(relative_path);
    joined
}
/// Renders `path` as a string with every `\` turned into `/`.
///
/// Non-UTF-8 bytes are replaced lossily via `to_string_lossy`.
fn normalize_path(path: PathBuf) -> String {
    let text = path.to_string_lossy();
    text.chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
/// Formats a CRC32C checksum as exactly 8 lowercase, zero-padded hex digits.
fn encode_crc32c(sum: u32) -> String {
    format!("{:08x}", sum)
}
/// Per-transfer timeout in seconds: `spec.timeout_seconds` when it is positive, otherwise 30.
fn timeout_seconds(spec: &TransferSpec) -> u64 {
    const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
    match spec.timeout_seconds {
        secs if secs > 0 => secs as u64,
        _ => DEFAULT_TIMEOUT_SECONDS,
    }
}
/// Average throughput in MiB/s for `bytes` transferred over `elapsed`.
///
/// Returns 0.0 for a non-positive byte count or a zero duration.
fn throughput(bytes: i64, elapsed: Duration) -> f64 {
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
    if bytes > 0 && !elapsed.is_zero() {
        bytes as f64 / BYTES_PER_MIB / elapsed.as_secs_f64()
    } else {
        0.0
    }
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_millis() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as i64,
        Err(_) => 0,
    }
}
/// Collapses chunks that share a `chunk_id` and returns the survivors ordered by `chunk_id`.
///
/// When an id repeats, the occurrence appearing last in the input is kept.
fn dedupe_chunks(chunks: Vec<TransferredChunk>) -> Vec<TransferredChunk> {
    let mut by_id: BTreeMap<String, TransferredChunk> = BTreeMap::new();
    // `extend` inserts in iteration order, so a later duplicate overwrites an earlier one.
    by_id.extend(chunks.into_iter().map(|chunk| (chunk.chunk_id.clone(), chunk)));
    by_id.into_values().collect()
}
/// Opens a new RDMA export session over a byte range of a source file and stores it in the
/// process-wide session map.
///
/// `request` is a JSON `OpenRDMAExportRequestOwned` (`rootPath`, `relativePath`, `offset`,
/// `length`). Two buffers of `length` bytes are allocated and both are registered with a freshly
/// opened RDMA transport. Only the primary buffer is filled from the file here; the standby
/// buffer is the one later refilled by `update_export_json`.
///
/// Returns the JSON `OpenRDMAExportResponse`, containing:
/// - the new session id,
/// - the primary slot's connection info,
/// - the CRC32C of the loaded bytes.
///
/// # Errors
/// Fails on any of:
/// - malformed JSON, a negative offset, or a non-positive length;
/// - a rejected or unopenable source path, or a short read;
/// - RDMA transport, registration, or endpoint failures;
/// - a poisoned session-map lock, or a serialization failure.
pub fn open_export_json(request: &str) -> Result<String, String> {
let req: OpenRDMAExportRequestOwned =
serde_json::from_str(request).map_err(|err| format!("parse open export request: {err}"))?;
if req.offset < 0 || req.length <= 0 {
return Err("offset must be non-negative and length must be positive".to_string());
}
// Start of the end-to-end timer reported as `total_ms` (the binding IS used despite its `_` prefix).
let _started = Instant::now();
let source_path = resolve_source_path_checked(&req.root_path, &req.relative_path)?;
let source_file = File::open(&source_path).map_err(|err| format!("open source file {}: {err}", source_path.display()))?;
// Double buffering: `primary` is served right away, `standby` is reloaded by update_export_json.
// Both are declared before the transport and the registrations, so on an early `?` return they
// are dropped last (locals drop in reverse declaration order).
let mut primary = vec![0_u8; req.length as usize];
let mut standby = vec![0_u8; req.length as usize];
let load_started = Instant::now();
load_source_range_from_file(&source_file, &source_path, req.offset, req.length, &mut primary)?;
let load_ms = load_started.elapsed().as_millis();
let transport_started = Instant::now();
let transport = RdmaTransport::open_default().map_err(|err| format!("open source rdma transport: {err}"))?;
let transport_ms = transport_started.elapsed().as_millis();
let register_started = Instant::now();
// The registrations are declared after `transport`, so on an error path they are dropped before
// it. That is the same MR-before-transport order the ExportSession field-order test requires of
// the stored session.
let primary_registered = transport
.register_memory(&mut primary)
.map_err(|err| format!("register source rdma memory: {err}"))?;
let standby_registered = transport
.register_memory(&mut standby)
.map_err(|err| format!("register standby source rdma memory: {err}"))?;
let register_ms = register_started.elapsed().as_millis();
let endpoint_started = Instant::now();
let primary_endpoint = transport
.local_endpoint(&primary_registered)
.map_err(|err| format!("query source rdma endpoint: {err}"))?;
let standby_endpoint = transport
.local_endpoint(&standby_registered)
.map_err(|err| format!("query standby source rdma endpoint: {err}"))?;
let endpoint_ms = endpoint_started.elapsed().as_millis();
// Session ids combine wall-clock milliseconds with the process-wide SESSION_COUNTER.
let session_id = format!("rdma-{}-{}", now_millis(), SESSION_COUNTER.fetch_add(1, Ordering::Relaxed));
// Checksum over the whole primary buffer, i.e. exactly the `length` bytes just loaded.
let expected_crc32c = encode_crc32c(crc32c::crc32c(&primary));
let session = ExportSession {
source_path: Some(source_path),
source_file: Some(source_file),
transport,
slots: [
ExportSlot {
buffer: primary,
registered: primary_registered,
endpoint: primary_endpoint,
},
ExportSlot {
buffer: standby,
registered: standby_registered,
endpoint: standby_endpoint,
},
],
// Slot 0 (primary) is advertised first; update_export_json flips to the other slot.
active_slot: 0,
// Allows update_export_json to reload the standby slot for this session.
mutable: true,
};
export_sessions()
.lock()
.map_err(|_| "lock export sessions".to_string())?
.insert(session_id.clone(), Arc::new(Mutex::new(session)));
// NOTE(review): `primary_endpoint` is read here after being stored in the slot above, which
// presumably relies on `RdmaEndpointInfo` being `Copy` — confirm.
let response = OpenRDMAExportResponse {
session_id: session_id.clone(),
transport_path: "RDMA".to_string(),
connection_info: rdma_endpoint_to_wire("", &primary_endpoint),
expected_crc32c,
};
let total_ms = _started.elapsed().as_millis();
eprintln!(
"rdma export open session={} relative_path={} offset={} length={} load_ms={} transport_ms={} register_ms={} endpoint_ms={} total_ms={}",
session_id,
req.relative_path,
req.offset,
req.length,
load_ms,
transport_ms,
register_ms,
endpoint_ms,
total_ms
);
serde_json::to_string(&response).map_err(|err| format!("serialize open export response: {err}"))
}
pub fn update_export_json(request: &str) -> Result<String, String> {
let req: UpdateRDMAExportRequestOwned =
serde_json::from_str(request).map_err(|err| format!("parse update export request: {err}"))?;
if req.offset < 0 || req.length <= 0 {
return Err("offset must be non-negative and length must be positive".to_string());
}
let _started = Instant::now();
let session = {
let sessions = export_sessions().lock().map_err(|_| "lock export sessions".to_string())?;
sessions
.get(&req.session_id)
.cloned()
.ok_or_else(|| format!("rdma export session {} not found", req.session_id))?
};
let mut session = session
.lock()
.map_err(|_| format!("lock export session {}", req.session_id))?;
if !session.mutable {
return Err(format!("rdma export session {} does not support update", req.session_id));
}
let source_file = session
.source_file
.as_ref()
.ok_or_else(|| format!("source file missing for session {}", req.session_id))?
.try_clone()
.map_err(|err| format!("clone source file for session {}: {err}", req.session_id))?;
let source_path = session
.source_path
.clone()
.ok_or_else(|| format!("source path missing for session {}", req.session_id))?;
let next_slot = 1 - session.active_slot;
let load_started = Instant::now();
{
let slot = &mut session.slots[next_slot];
if req.length as usize > slot.buffer.len() {
return Err(format!(
"requested length {} exceeds session capacity {}",
req.length,
slot.buffer.len()
));
}
load_source_range_from_file(
&source_file,
&source_path,
req.offset,
req.length,
&mut slot.buffer[..req.length as usize],
)?;
let _registered = &slot.registered;
slot.endpoint.length = req.length as u64;
}
session.active_slot = next_slot;
let load_ms = load_started.elapsed().as_millis();
let total_ms = _started.elapsed().as_millis();
let response = UpdateRDMAExportResponse {
session_id: req.session_id,
transport_path: "RDMA".to_string(),
connection_info: rdma_endpoint_to_wire("", &session.slots[next_slot].endpoint),
expected_crc32c: encode_crc32c(crc32c::crc32c(&session.slots[next_slot].buffer[..req.length as usize])),
};
eprintln!(
"rdma export update session={} offset={} length={} slot={} load_ms={} total_ms={}",
response.session_id,
req.offset,
req.length,
next_slot,
load_ms,
total_ms
);
serde_json::to_string(&response).map_err(|err| format!("serialize update export response: {err}"))
}
pub fn connect_export_json(request: &str) -> Result<String, String> {
let req: ConnectRDMAExportRequestOwned =
serde_json::from_str(request).map_err(|err| format!("parse connect export request: {err}"))?;
let started = Instant::now();
let remote = rdma_endpoint_from_wire(&req.connection_info)?;
let session = {
let sessions = export_sessions().lock().map_err(|_| "lock export sessions".to_string())?;
sessions
.get(&req.session_id)
.cloned()
.ok_or_else(|| format!("rdma export session {} not found", req.session_id))?
};
let session = session
.lock()
.map_err(|_| format!("lock export session {}", req.session_id))?;
session
.transport
.connect_rc(&remote)
.map_err(|err| format!("connect source qp for session {}: {err}", req.session_id))?;
let response = ConnectRDMAExportResponse {
session_id: req.session_id,
transport_path: "RDMA".to_string(),
};
eprintln!(
"rdma export connect session={} total_ms={}",
response.session_id,
started.elapsed().as_millis()
);
serde_json::to_string(&response).map_err(|err| format!("serialize connect export response: {err}"))
}
/// Removes an RDMA export session from the process-wide registry.
///
/// `request` is a JSON `CloseRDMAExportRequestOwned` (`sessionID`). The response reports
/// `closed: true` only if a session with that id was present. Closing an unknown id is not an
/// error.
///
/// # Errors
/// Fails on malformed JSON, a poisoned registry lock, or a serialization failure.
pub fn close_export_json(request: &str) -> Result<String, String> {
    let req: CloseRDMAExportRequestOwned =
        serde_json::from_str(request).map_err(|err| format!("parse close export request: {err}"))?;
    // The registry guard is a temporary dropped at the end of this statement. The removed handle
    // lives until the function returns, so its (possibly final) drop happens outside the
    // registry lock.
    let removed = export_sessions()
        .lock()
        .map_err(|_| "lock export sessions".to_string())?
        .remove(&req.session_id);
    let closed = removed.is_some();
    let response = CloseRDMAExportResponseOwned {
        session_id: req.session_id,
        closed,
    };
    serde_json::to_string(&response).map_err(|err| format!("serialize close export response: {err}"))
}
/// JSON body accepted by `open_export_json` (camelCase keys): which byte range of which file to
/// export over RDMA.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct OpenRDMAExportRequestOwned {
/// Directory that `relative_path` is resolved under (checked by `resolve_source_path_checked`).
root_path: String,
/// Path of the source file relative to `root_path`.
relative_path: String,
/// Byte offset of the range; `open_export_json` rejects negative values.
offset: i64,
/// Range length in bytes; `open_export_json` rejects non-positive values.
length: i64,
}
/// JSON body accepted by `connect_export_json`: the session to connect and the remote peer's
/// RDMA endpoint.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConnectRDMAExportRequestOwned {
/// Explicit rename: camelCase would produce `sessionId`, but the wire key is `sessionID`.
#[serde(rename = "sessionID")]
session_id: String,
/// Remote endpoint, decoded with `rdma_endpoint_from_wire`.
connection_info: RDMAConnectionInfo,
}
/// JSON body accepted by `update_export_json`: the new byte range to load into an existing
/// session.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UpdateRDMAExportRequestOwned {
/// Explicit rename: camelCase would produce `sessionId`, but the wire key is `sessionID`.
#[serde(rename = "sessionID")]
session_id: String,
/// Byte offset in the session's source file; `update_export_json` rejects negative values.
offset: i64,
/// Range length; must be positive and no larger than the session's slot buffers.
length: i64,
}
/// JSON body accepted by `close_export_json`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloseRDMAExportRequestOwned {
/// Explicit rename: camelCase would produce `sessionId`, but the wire key is `sessionID`.
#[serde(rename = "sessionID")]
session_id: String,
}
/// Close-session response. It is produced by `close_export_json` and decoded by
/// `close_rdma_session` on the requesting side.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloseRDMAExportResponseOwned {
/// Explicit rename: camelCase would produce `sessionId`, but the wire key is `sessionID`.
#[serde(rename = "sessionID")]
session_id: String,
/// True when the session existed and was removed from the registry.
closed: bool,
}
/// Process-wide registry of open RDMA export sessions, keyed by session id.
///
/// The registry is created empty on first access. Each session has its own mutex, so callers
/// can clone the `Arc` and drop the registry lock before working on a session.
fn export_sessions() -> &'static Mutex<HashMap<String, Arc<Mutex<ExportSession>>>> {
    EXPORT_SESSIONS.get_or_init(Mutex::default)
}
fn rdma_endpoint_to_wire(device_name: &str, endpoint: &RdmaEndpointInfo) -> RDMAConnectionInfo {
RDMAConnectionInfo {
device_name: device_name.to_string(),
qpn: endpoint.qpn,
psn: endpoint.psn,
lid: endpoint.lid,
gid: hex::encode(endpoint.gid),
rkey: endpoint.rkey,
addr: endpoint.addr,
length: endpoint.length,
}
}
/// Parses wire-format connection info back into an `RdmaEndpointInfo`.
///
/// # Errors
/// Fails if `gid` is not valid hex or does not decode to exactly 16 bytes.
fn rdma_endpoint_from_wire(value: &RDMAConnectionInfo) -> Result<RdmaEndpointInfo, String> {
    let decoded = match hex::decode(&value.gid) {
        Ok(bytes) => bytes,
        Err(err) => return Err(format!("decode gid: {err}")),
    };
    if decoded.len() != 16 {
        return Err(format!("gid length must be 16 bytes, got {}", decoded.len()));
    }
    let gid = {
        let mut raw = [0_u8; 16];
        raw.copy_from_slice(&decoded);
        raw
    };
    Ok(RdmaEndpointInfo {
        gid,
        qpn: value.qpn,
        psn: value.psn,
        lid: value.lid,
        rkey: value.rkey,
        addr: value.addr,
        length: value.length,
    })
}
// Unit tests for the transfer-engine helpers in this module.
//
// NOTE(review): several names used below are neither in the `use super::{...}` list nor
// imported any other way in this module: `is_ring_relay_collective`, `RingPeerPlan`,
// `CollectivePeerPlan`, `RelayBatchPublisher`, `RelayRootPublisher`, `RELAY_PUBLISH_BATCH_BYTES`,
// `JobOutcome`, `PayloadTransport`, `RelayRdmaHint`, `RDMAConnectionInfo`, `Client` and
// `normalize_agent_endpoint`. Unless something else brings them into scope (e.g. a
// `use super::*;`), the test build cannot resolve them — confirm.
// NOTE(review): some struct literals disagree with each other or with non-test code:
// - `ByteRange` is built both with and without `relay_offset`.
// - `SourceEndpoint` is built without `relay_rdma`, yet the last deserialization test reads
//   `endpoint.relay_rdma`.
// - `RDMAConnectionInfo` is built here with `mode`/`qp_num`, but with `device_name`/`qpn`/`psn`
//   in `rdma_endpoint_to_wire`.
// Verify against the current type definitions.
#[cfg(test)]
mod tests {
use super::{
build_direct_striped_probe_chunk, build_job_groups, build_jobs, desired_worker_count, resolve_source_path_checked,
resolve_relay_root_publisher, select_direct_striped_segments, split_jobs_by_source, ArtifactFile, ByteRange,
CollectiveSpec, LogicalManifest, PushCollectiveChunkRequestOwned, SourceEndpoint, SourceSegmentPlan,
TransferJob, TransferSpec, TransferredChunk,
};
use std::fs;
use std::path::PathBuf;
/// Source-text check: inside `struct ExportSession`, the `slots` field must be declared before
/// `transport`. Fields drop in declaration order, so the slots (and their registered memory) go
/// first. NOTE(review): reads `mod.rs` via `include_str!`, so it assumes this file is `mod.rs`.
#[test]
fn export_session_drops_registered_memory_before_transport() {
let source = include_str!("mod.rs");
let start = source.find("struct ExportSession").unwrap();
let end = source[start..].find("\n}").map(|idx| start + idx).unwrap();
let block = &source[start..end];
let slots_pos = block.find("slots: [ExportSlot; 2]").unwrap();
let transport_pos = block.find("transport: RdmaTransport").unwrap();
assert!(
slots_pos < transport_pos,
"ExportSession fields must drop slots/MRs before transport destroys QP/CQ/PD"
);
}
/// Same source-text field-order check for `ReusableRdmaSession`: `local_slots` must be declared
/// (and therefore dropped) before `transport`.
#[test]
fn reusable_rdma_session_drops_local_slots_before_transport() {
let source = include_str!("mod.rs");
let start = source.find("struct ReusableRdmaSession").unwrap();
let end = source[start..].find("\n}").map(|idx| start + idx).unwrap();
let block = &source[start..end];
let slots_pos = block.find("local_slots: Vec<Arc<LocalReadSlot>>").unwrap();
let transport_pos = block.find("transport: RdmaTransport").unwrap();
assert!(
slots_pos < transport_pos,
"ReusableRdmaSession fields must drop local slots/MRs before transport destroys QP/CQ/PD"
);
}
/// The Go-style JSON key `enableChunkCRC32C` deserializes into `TransferSpec::enable_chunk_crc32c`.
#[test]
fn transfer_spec_accepts_go_crc32c_json_field() {
let spec: TransferSpec = serde_json::from_value(serde_json::json!({
"taskID": "task-crc",
"artifactKey": "artifact",
"transferMode": "SINGLE_SOURCE_DIRECT",
"logicalManifest": {
"artifactKey": "artifact",
"rootPath": "/source",
"chunkSizeBytes": 4,
"digest": "",
"generatedAt": 0,
"files": []
},
"sourceSegments": [],
"targetTempPath": "/target/.staging/artifact",
"targetFinalPath": "/target/artifact",
"enableChunkCRC32C": true,
"chunkSizeBytes": 4
}))
.unwrap();
assert!(spec.enable_chunk_crc32c);
}
/// A symlink inside the root whose target lies outside it (`../../blobs/weights`) is accepted.
/// The returned path is the unresolved link path, and its metadata follows the link (7 bytes).
#[cfg(unix)]
#[test]
fn resolve_source_path_checked_allows_snapshot_symlink_targets() {
use std::os::unix::fs::symlink;
let root = std::env::temp_dir().join(format!("rdma-engine-test-{}", std::process::id()));
let _ = fs::remove_dir_all(&root);
let blobs = root.join("blobs");
let snapshot = root.join("snapshots").join("rev1");
fs::create_dir_all(&blobs).unwrap();
fs::create_dir_all(&snapshot).unwrap();
fs::write(blobs.join("weights"), b"weights").unwrap();
symlink(PathBuf::from("..").join("..").join("blobs").join("weights"), snapshot.join("model.safetensors"))
.unwrap();
let resolved = resolve_source_path_checked(
snapshot.to_str().unwrap(),
"model.safetensors",
)
.unwrap();
assert_eq!(resolved, snapshot.join("model.safetensors"));
assert_eq!(fs::metadata(resolved).unwrap().len(), 7);
let _ = fs::remove_dir_all(&root);
}
/// Two jobs per source with a budget of 2 yield two groups, each holding only one source's jobs,
/// in source order.
#[test]
fn split_jobs_by_source_keeps_one_group_per_source() {
let source_a = SourceEndpoint {
source_id: "source-a".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-a".to_string(),
path: "/models/a".to_string(),
node_name: Some("m0".to_string()),
};
let source_b = SourceEndpoint {
source_id: "source-b".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-b".to_string(),
path: "/models/b".to_string(),
node_name: Some("n2".to_string()),
};
let jobs = vec![
TransferJob {
source_id: "source-a".to_string(),
source: source_a.clone(),
relative_path: "weights.safetensors".to_string(),
offset: 0,
length: 64,
},
TransferJob {
source_id: "source-a".to_string(),
source: source_a,
relative_path: "weights.safetensors".to_string(),
offset: 64,
length: 64,
},
TransferJob {
source_id: "source-b".to_string(),
source: source_b.clone(),
relative_path: "weights.safetensors".to_string(),
offset: 128,
length: 64,
},
TransferJob {
source_id: "source-b".to_string(),
source: source_b,
relative_path: "weights.safetensors".to_string(),
offset: 192,
length: 64,
},
];
let groups = split_jobs_by_source(jobs, 2);
assert_eq!(groups.len(), 2);
assert_eq!(groups[0].len(), 2);
assert_eq!(groups[1].len(), 2);
assert!(groups[0].iter().all(|job| job.source_id == "source-a"));
assert!(groups[1].iter().all(|job| job.source_id == "source-b"));
}
/// With 25-byte chunks and two 50-byte source ranges over a 100-byte file, `build_jobs` yields
/// four 25-byte jobs at offsets 0, 25, 50 and 75.
#[test]
fn direct_striped_build_jobs_uses_manifest_chunk_boundaries() {
let source_a = SourceEndpoint {
source_id: "source-a".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-a".to_string(),
path: "/models/shared".to_string(),
node_name: Some("m0".to_string()),
};
let source_b = SourceEndpoint {
source_id: "source-b".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-b".to_string(),
path: "/models/shared".to_string(),
node_name: Some("n2".to_string()),
};
let spec = TransferSpec {
task_id: "task".to_string(),
artifact_key: "artifact".to_string(),
transfer_mode: "DIRECT_STRIPED".to_string(),
logical_manifest: LogicalManifest {
artifact_key: "artifact".to_string(),
root_path: "/models/shared".to_string(),
chunk_size_bytes: 25,
digest: "digest".to_string(),
generated_at: 0,
files: vec![ArtifactFile {
relative_path: "weights.safetensors".to_string(),
size_bytes: 100,
kind: "SAFETENSORS".to_string(),
chunkable: true,
required: true,
}],
},
collective_spec: CollectiveSpec {
session_id: String::new(),
mode: String::new(),
ring: None,
peers: Vec::new(),
},
source_segments: vec![
SourceSegmentPlan {
source_id: "source-a".to_string(),
source_endpoint: source_a,
byte_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 0,
end: 50,
}],
weight: 1,
},
SourceSegmentPlan {
source_id: "source-b".to_string(),
source_endpoint: source_b,
byte_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 50,
end: 100,
}],
weight: 1,
},
],
target_temp_path: "/tmp/target".to_string(),
preserve_existing: false,
enable_chunk_crc32c: false,
chunk_size_bytes: 25,
parallelism: 4,
timeout_seconds: 30,
};
let jobs = build_jobs(&spec, &spec.logical_manifest.files).unwrap();
assert_eq!(jobs.len(), 4);
assert!(jobs.iter().all(|job| job.length == 25));
let mut offsets = jobs.iter().map(|job| job.offset).collect::<Vec<_>>();
offsets.sort_unstable();
assert_eq!(offsets, vec![0, 25, 50, 75]);
}
/// With source weights 280 and 220, `select_direct_striped_segments` keeps only `source-a`.
/// Presumably one source is considered able to saturate the target by itself — see the test name.
#[test]
fn direct_striped_prefers_single_source_when_one_source_can_saturate_target() {
let spec = TransferSpec {
task_id: "task".to_string(),
artifact_key: "artifact".to_string(),
transfer_mode: "DIRECT_STRIPED".to_string(),
logical_manifest: LogicalManifest {
artifact_key: "artifact".to_string(),
root_path: "/models/shared".to_string(),
chunk_size_bytes: 25,
digest: "digest".to_string(),
generated_at: 0,
files: Vec::new(),
},
collective_spec: CollectiveSpec {
session_id: String::new(),
mode: String::new(),
ring: None,
peers: Vec::new(),
},
source_segments: Vec::new(),
target_temp_path: "/tmp/target".to_string(),
preserve_existing: false,
enable_chunk_crc32c: false,
chunk_size_bytes: 25,
parallelism: 4,
timeout_seconds: 30,
};
let selected = select_direct_striped_segments(
&spec,
vec![
SourceSegmentPlan {
source_id: "source-a".to_string(),
source_endpoint: SourceEndpoint {
source_id: "source-a".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-a".to_string(),
path: "/models/a".to_string(),
node_name: Some("m0".to_string()),
},
byte_ranges: Vec::new(),
weight: 280,
},
SourceSegmentPlan {
source_id: "source-b".to_string(),
source_endpoint: SourceEndpoint {
source_id: "source-b".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-b".to_string(),
path: "/models/b".to_string(),
node_name: Some("n2".to_string()),
},
byte_ranges: Vec::new(),
weight: 220,
},
],
);
assert_eq!(selected.len(), 1);
assert_eq!(selected[0].source_id, "source-a");
}
/// With lower weights (180 and 170), both sources are kept.
#[test]
fn direct_striped_keeps_multiple_sources_when_no_single_source_can_saturate_target() {
let spec = TransferSpec {
task_id: "task".to_string(),
artifact_key: "artifact".to_string(),
transfer_mode: "DIRECT_STRIPED".to_string(),
logical_manifest: LogicalManifest {
artifact_key: "artifact".to_string(),
root_path: "/models/shared".to_string(),
chunk_size_bytes: 25,
digest: "digest".to_string(),
generated_at: 0,
files: Vec::new(),
},
collective_spec: CollectiveSpec {
session_id: String::new(),
mode: String::new(),
ring: None,
peers: Vec::new(),
},
source_segments: Vec::new(),
target_temp_path: "/tmp/target".to_string(),
preserve_existing: false,
enable_chunk_crc32c: false,
chunk_size_bytes: 25,
parallelism: 4,
timeout_seconds: 30,
};
let selected = select_direct_striped_segments(
&spec,
vec![
SourceSegmentPlan {
source_id: "source-a".to_string(),
source_endpoint: SourceEndpoint {
source_id: "source-a".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-a".to_string(),
path: "/models/a".to_string(),
node_name: Some("m0".to_string()),
},
byte_ranges: Vec::new(),
weight: 180,
},
SourceSegmentPlan {
source_id: "source-b".to_string(),
source_endpoint: SourceEndpoint {
source_id: "source-b".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-b".to_string(),
path: "/models/b".to_string(),
node_name: Some("n2".to_string()),
},
byte_ranges: Vec::new(),
weight: 170,
},
],
);
assert_eq!(selected.len(), 2);
}
/// For DIRECT_STRIPED with parallelism 4, `desired_worker_count(&spec, 8, 1)` is 4 and
/// `desired_worker_count(&spec, 8, 2)` is 2. The third argument is presumably the source count —
/// confirm against `desired_worker_count`.
#[test]
fn direct_striped_single_source_uses_parallelism_for_worker_budget() {
let spec = TransferSpec {
task_id: "task".to_string(),
artifact_key: "artifact".to_string(),
transfer_mode: "DIRECT_STRIPED".to_string(),
logical_manifest: LogicalManifest {
artifact_key: "artifact".to_string(),
root_path: "/models/shared".to_string(),
chunk_size_bytes: 25,
digest: "digest".to_string(),
generated_at: 0,
files: Vec::new(),
},
collective_spec: CollectiveSpec {
session_id: String::new(),
mode: String::new(),
ring: None,
peers: Vec::new(),
},
source_segments: Vec::new(),
target_temp_path: "/tmp/target".to_string(),
preserve_existing: false,
enable_chunk_crc32c: false,
chunk_size_bytes: 25,
parallelism: 4,
timeout_seconds: 30,
};
assert_eq!(desired_worker_count(&spec, 8, 1), 4);
assert_eq!(desired_worker_count(&spec, 8, 2), 2);
}
/// In a two-rank RING collective where only rank 0 (self) owns ranges, the worker budget is 2
/// even though parallelism is 4.
#[test]
fn relay_root_caps_worker_budget_at_two() {
let spec = TransferSpec {
task_id: "task".to_string(),
artifact_key: "artifact".to_string(),
transfer_mode: "PARTIAL_PULL_ALLGATHER".to_string(),
logical_manifest: LogicalManifest {
artifact_key: "artifact".to_string(),
root_path: "/models/shared".to_string(),
chunk_size_bytes: 25,
digest: "digest".to_string(),
generated_at: 0,
files: Vec::new(),
},
collective_spec: CollectiveSpec {
session_id: "plan-collective".to_string(),
mode: "RING".to_string(),
ring: Some(RingPeerPlan {
self_node: "n0".to_string(),
self_endpoint: "http://n0".to_string(),
rank: 0,
world_size: 2,
}),
peers: vec![
CollectivePeerPlan {
node_name: "n0".to_string(),
endpoint: "http://n0".to_string(),
rank: 0,
staging_path: "/tmp/root".to_string(),
owned_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 0,
end: 100,
relay_offset: 0,
}],
},
CollectivePeerPlan {
node_name: "n2".to_string(),
endpoint: "http://n2".to_string(),
rank: 1,
staging_path: "/tmp/peer".to_string(),
owned_ranges: Vec::new(),
},
],
},
source_segments: Vec::new(),
target_temp_path: "/tmp/target".to_string(),
preserve_existing: false,
enable_chunk_crc32c: false,
chunk_size_bytes: 25,
parallelism: 4,
timeout_seconds: 30,
};
assert_eq!(desired_worker_count(&spec, 8, 1), 2);
}
/// When every rank owns a range, the spec is a ring-relay collective and the worker budget
/// equals the requested parallelism (3).
#[test]
fn symmetric_fanout_honors_requested_parallelism() {
let spec = TransferSpec {
task_id: "task".to_string(),
artifact_key: "artifact".to_string(),
transfer_mode: "PARTIAL_PULL_ALLGATHER".to_string(),
logical_manifest: LogicalManifest {
artifact_key: "artifact".to_string(),
root_path: "/models/shared".to_string(),
chunk_size_bytes: 25,
digest: "digest".to_string(),
generated_at: 0,
files: Vec::new(),
},
collective_spec: CollectiveSpec {
session_id: "plan-collective".to_string(),
mode: "RING".to_string(),
ring: Some(RingPeerPlan {
self_node: "n0".to_string(),
self_endpoint: "http://n0".to_string(),
rank: 0,
world_size: 2,
}),
peers: vec![
CollectivePeerPlan {
node_name: "n0".to_string(),
endpoint: "http://n0".to_string(),
rank: 0,
staging_path: "/tmp/n0".to_string(),
owned_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 0,
end: 100,
relay_offset: 0,
}],
},
CollectivePeerPlan {
node_name: "n2".to_string(),
endpoint: "http://n2".to_string(),
rank: 1,
staging_path: "/tmp/n2".to_string(),
owned_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 100,
end: 200,
relay_offset: 0,
}],
},
],
},
source_segments: Vec::new(),
target_temp_path: "/tmp/target".to_string(),
preserve_existing: false,
enable_chunk_crc32c: false,
chunk_size_bytes: 25,
parallelism: 3,
timeout_seconds: 30,
};
assert!(is_ring_relay_collective(&spec));
assert_eq!(desired_worker_count(&spec, 8, 1), 3);
}
/// A three-rank ring is still a ring-relay collective. The relay publisher resolves to the
/// collective session id and this node's endpoint, and the worker budget is 4.
#[test]
fn ring_relay_collective_supports_world_size_three() {
let spec = TransferSpec {
task_id: "task".to_string(),
artifact_key: "artifact".to_string(),
transfer_mode: "PARTIAL_PULL_ALLGATHER".to_string(),
logical_manifest: LogicalManifest {
artifact_key: "artifact".to_string(),
root_path: "/models/shared".to_string(),
chunk_size_bytes: 25,
digest: "digest".to_string(),
generated_at: 0,
files: Vec::new(),
},
collective_spec: CollectiveSpec {
session_id: "plan-collective".to_string(),
mode: "RING".to_string(),
ring: Some(RingPeerPlan {
self_node: "n0".to_string(),
self_endpoint: "http://n0".to_string(),
rank: 0,
world_size: 3,
}),
peers: vec![
CollectivePeerPlan {
node_name: "n0".to_string(),
endpoint: "http://n0".to_string(),
rank: 0,
staging_path: "/tmp/n0".to_string(),
owned_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 0,
end: 100,
relay_offset: 0,
}],
},
CollectivePeerPlan {
node_name: "n1".to_string(),
endpoint: "http://n1".to_string(),
rank: 1,
staging_path: "/tmp/n1".to_string(),
owned_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 100,
end: 200,
relay_offset: 0,
}],
},
CollectivePeerPlan {
node_name: "n2".to_string(),
endpoint: "http://n2".to_string(),
rank: 2,
staging_path: "/tmp/n2".to_string(),
owned_ranges: vec![ByteRange {
relative_path: "weights.safetensors".to_string(),
start: 200,
end: 300,
relay_offset: 0,
}],
},
],
},
source_segments: Vec::new(),
target_temp_path: "/tmp/target".to_string(),
preserve_existing: false,
enable_chunk_crc32c: false,
chunk_size_bytes: 25,
parallelism: 4,
timeout_seconds: 30,
};
let publisher = resolve_relay_root_publisher(&spec).expect("resolve relay publisher");
assert!(is_ring_relay_collective(&spec));
assert_eq!(publisher.session_id, "plan-collective");
assert_eq!(publisher.endpoint, "http://n0");
assert_eq!(desired_worker_count(&spec, 8, 1), 4);
}
/// With group count 2, the first enqueued group is buffered (`None`). The second completes
/// iteration 1 with two single-outcome groups totalling 128 bytes (2 x 64).
#[test]
fn relay_publisher_pairs_two_groups_into_one_iteration() {
let publisher = RelayBatchPublisher::new(
Client::builder().build().unwrap(),
RelayRootPublisher {
task_id: "task".to_string(),
session_id: "session".to_string(),
source_node: "n0".to_string(),
endpoint: "http://n2".to_string(),
},
"/tmp/target".to_string(),
2,
1,
RELAY_PUBLISH_BATCH_BYTES,
);
let outcome = JobOutcome {
chunk: TransferredChunk {
chunk_id: "chunk-0".to_string(),
file_path: "/tmp/target/blob4g.safetensors".to_string(),
relative_path: "blob4g.safetensors".to_string(),
offset: 0,
size: 64,
crc32c: String::new(),
source_id: "source-0".to_string(),
},
transport: PayloadTransport::Rdma,
read_ms: 1,
write_ms: 1,
total_ms: 2,
};
assert!(publisher.enqueue(vec![outcome.clone()], false).unwrap().is_none());
let ready = publisher.enqueue(vec![outcome], false).unwrap().unwrap();
assert_eq!(ready.0, 1);
assert_eq!(ready.1.len(), 2);
assert_eq!(ready.1[0].len(), 1);
assert_eq!(ready.1[1].len(), 1);
assert_eq!(ready.2, 128);
}
/// For DIRECT_STRIPED with one source, `build_job_groups` splits four jobs into two groups of
/// two (contiguous per the test name).
#[test]
fn direct_striped_single_source_uses_contiguous_job_groups() {
let source = SourceEndpoint {
source_id: "source-a".to_string(),
source_type: "node".to_string(),
endpoint: "http://source-a".to_string(),
path: "/models/a".to_string(),
node_name: Some("m0".to_string()),
};
let jobs = vec![
TransferJob {
source_id: "source-a".to_string(),
source: source.clone(),
relative_path: "weights.safetensors".to_string(),
offset: 0,
length: 64,
},
TransferJob {
source_id: "source-a".to_string(),
source: source.clone(),
relative_path: "weights.safetensors".to_string(),
offset: 64,
length: 64,
},
TransferJob {
source_id: "source-a".to_string(),
source: source.clone(),
relative_path: "weights.safetensors".to_string(),
offset: 128,
length: 64,
},
TransferJob {
source_id: "source-a".to_string(),
source,
relative_path: "weights.safetensors".to_string(),
offset: 192,
length: 64,
},
];
let groups = build_job_groups("DIRECT_STRIPED", jobs, 2, 1);
assert_eq!(groups.len(), 2);
assert_eq!(groups[0].len(), 2);
assert_eq!(groups[1].len(), 2);
}
/// The probe chunk is the first 4 MiB of `weights.safetensors` (16 MiB). The 2 MiB file is
/// not chosen.
#[test]
fn direct_striped_probe_chunk_uses_small_4m_window() {
let files = vec![
ArtifactFile {
relative_path: "weights.safetensors".to_string(),
size_bytes: 16 * 1024 * 1024,
kind: "SAFETENSORS".to_string(),
chunkable: true,
required: true,
},
ArtifactFile {
relative_path: "small.safetensors".to_string(),
size_bytes: 2 * 1024 * 1024,
kind: "SAFETENSORS".to_string(),
chunkable: true,
required: true,
},
];
let probe = build_direct_striped_probe_chunk(&files).unwrap();
assert_eq!(probe.relative_path, "weights.safetensors");
assert_eq!(probe.offset, 0);
assert_eq!(probe.length, 4 * 1024 * 1024);
}
/// `collectiveSpec.sessionID` from Go-style JSON survives deserialization and is used as the
/// relay publisher's session id.
#[test]
fn relay_collective_spec_preserves_session_id() {
let request = serde_json::json!({
"taskID": "plan-n0",
"artifactKey": "artifact",
"transferMode": "PARTIAL_PULL_ALLGATHER",
"logicalManifest": {
"artifactKey": "artifact",
"rootPath": "/models/qwen",
"chunkSizeBytes": 67108864_i64,
"digest": "digest",
"generatedAt": 0,
"files": []
},
"collectiveSpec": {
"sessionID": "plan-collective",
"mode": "RING",
"ring": {
"selfNode": "n0",
"selfEndpoint": "192.168.200.25",
"rank": 0,
"worldSize": 2
},
"peers": [{
"nodeName": "n0",
"endpoint": "192.168.200.25",
"rank": 0,
"stagingPath": "/var/lib/weight-dispatcher/cache/staging/artifact",
"ownedRanges": [{
"relativePath": "weights.bin",
"start": 0,
"end": 1024
}]
}, {
"nodeName": "n2",
"endpoint": "192.168.200.14",
"rank": 1,
"stagingPath": "/var/lib/weight-dispatcher/cache/staging/artifact",
"ownedRanges": []
}]
},
"sourceSegments": [{
"sourceID": "source-0",
"sourceEndpoint": {
"sourceID": "source-0",
"sourceType": "node",
"endpoint": "http://192.168.200.15:18080",
"path": "/models/qwen",
"nodeName": "m0"
},
"byteRanges": [{
"relativePath": "weights.bin",
"start": 0,
"end": 1024
}]
}],
"targetTempPath": "/tmp/target",
"chunkSizeBytes": 67108864_i64,
"parallelism": 4,
"timeoutSeconds": 30
})
.to_string();
let spec: TransferSpec = serde_json::from_str(&request).expect("parse transfer spec");
let publisher = resolve_relay_root_publisher(&spec).expect("resolve relay publisher");
assert_eq!(spec.collective_spec.session_id, "plan-collective");
assert_eq!(publisher.session_id, "plan-collective");
}
/// A push request serializes with the Go wire keys `sessionID` and `relayRDMA`, not the
/// camelCase `sessionId` / `relayRdma`.
#[test]
fn relay_collective_push_serializes_session_id_in_go_wire_format() {
let payload = serde_json::to_value(&PushCollectiveChunkRequestOwned {
task_id: "plan-n0".to_string(),
session_id: "plan-collective".to_string(),
iteration: 1,
chunk: TransferredChunk {
chunk_id: "weights.bin:0".to_string(),
file_path: "/tmp/target/weights.bin".to_string(),
relative_path: "weights.bin".to_string(),
offset: 0,
size: 1024,
crc32c: String::new(),
source_id: "n0".to_string(),
},
data: None,
transport_path: "RDMA".to_string(),
source_node: "n0".to_string(),
relay_offset: 0,
relay_rdma: Some(RelayRdmaHint {
session_id: "relay-session".to_string(),
connection_info: RDMAConnectionInfo {
mode: "RDMA_RC".to_string(),
addr: 123,
rkey: 456,
qp_num: 789,
lid: 1,
gid: String::new(),
length: 1024,
},
}),
})
.expect("serialize push request");
assert_eq!(payload.get("sessionID").and_then(|value| value.as_str()), Some("plan-collective"));
assert!(payload.get("relayRDMA").is_some());
assert!(payload.get("relayRdma").is_none());
assert!(payload.get("sessionId").is_none());
}
/// A `SourceEndpoint` with a Go-style `relayRDMA` object deserializes, and the hint's session id
/// is readable.
#[test]
fn source_endpoint_deserializes_go_relay_rdma_field() {
let payload = serde_json::json!({
"sourceID": "relay-source",
"sourceType": "peer-relay",
"endpoint": "http://n0:18080",
"path": "/var/lib/weight-dispatcher/cache/staging/blob4g",
"nodeName": "n0",
"relayRDMA": {
"sessionID": "relay-session",
"connectionInfo": {
"mode": "RDMA_RC",
"addr": 123,
"rkey": 456,
"qpNum": 789,
"lid": 1,
"gid": "",
"length": 1024
}
}
})
.to_string();
let endpoint: SourceEndpoint = serde_json::from_str(&payload).expect("deserialize source endpoint");
assert_eq!(endpoint.source_id, "relay-source");
assert_eq!(
endpoint
.relay_rdma
.as_ref()
.map(|hint| hint.session_id.as_str()),
Some("relay-session")
);
}
/// A bare host name gets `http://` and the port taken from `AGENT_PORT`.
/// NOTE(review): mutates process-wide environment variables. Tests run on parallel threads by
/// default, so this can race with any other test that reads `AGENT_PORT`.
#[test]
fn normalize_agent_endpoint_uses_configured_agent_port() {
std::env::set_var("AGENT_PORT", "19090");
let endpoint = normalize_agent_endpoint("node-a").expect("normalize endpoint");
std::env::remove_var("AGENT_PORT");
assert_eq!(endpoint, "http://node-a:19090");
}
}