use std::io;
use ibverbs::{
CompletionQueue, Context, MemoryRegion, PreparedQueuePair, ProtectionDomain, QueuePair,
QueuePairBuilder, QueuePairEndpoint, RemoteMemoryRegion, devices, ibv_qp_type,
};
use crate::shmq::GeneralShmQueueBytes;
/// Errors reported by the RDMA queue's send/receive operations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RdmaQueueError {
    /// The remote ring has no free slot.
    Full,
    /// The local ring holds no unread item.
    Empty,
    /// Wrong role/state for the call, or an RDMA operation failed.
    IoError,
}

impl RdmaQueueError {
    /// Human-readable text shared by `Display` and the `io::Error` conversion.
    fn message(self) -> &'static str {
        match self {
            Self::Full => "Queue is full",
            Self::Empty => "Queue is empty",
            Self::IoError => "RDMA I/O error",
        }
    }
}

impl std::error::Error for RdmaQueueError {}

impl std::fmt::Display for RdmaQueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

/// Maps `Full`/`Empty` to `ErrorKind::WouldBlock` (retryable) and `IoError` to
/// `ErrorKind::Other`, keeping the same message text as `Display`.
impl From<RdmaQueueError> for io::Error {
    fn from(err: RdmaQueueError) -> Self {
        let kind = match err {
            RdmaQueueError::Full | RdmaQueueError::Empty => io::ErrorKind::WouldBlock,
            RdmaQueueError::IoError => io::ErrorKind::Other,
        };
        io::Error::new(kind, err.message())
    }
}
/// Control header stored at the very start of every queue's registered memory region.
///
/// `#[repr(C)]` fixes the field layout so individual fields can be addressed by byte
/// offset (see `offset_of_head` / `offset_of_tail`) in remote RDMA operations.
/// NOTE(review): the header is exchanged as raw bytes, so both peers must share the same
/// `usize` width and endianness — confirm deployments are homogeneous.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
struct RdmaQueueMeta {
/// Number of slots in the ring; one slot is kept empty so "full" differs from "empty".
capacity: usize,
/// Next slot the receiver reads; advanced locally by the receiver.
head: usize,
/// Next slot the sender writes; advanced remotely by the sender via RDMA write.
tail: usize,
/// Byte offset from the start of the region to slot 0.
data_offset: usize,
/// Set to `true` when the owning endpoint initializes the header.
ready: bool,
}
/// One endpoint (sender or receiver, per `is_sender`) of a ring-buffer queue carried over
/// an RDMA reliable-connected queue pair.
///
/// The sender writes items and the tail index directly into the receiver's registered
/// memory with one-sided RDMA writes; the receiver reads its own memory.
/// NOTE(review): fields drop in declaration order and `Drop for RdmaQueue` takes `qp` out
/// first; confirm the ibverbs version in use keeps the context/PD alive for the rest.
pub struct RdmaQueue<T> {
/// Kept alive for the queue's lifetime; not read after construction.
_context: Context,
/// Protection domain owning the memory region and queue pair; not read after construction.
_pd: ProtectionDomain,
/// Receive completion queue handed to `create_qp`; never polled here.
_recv_cq: CompletionQueue,
/// Polled by `wait_for_completion` for RDMA read/write completions.
send_cq: CompletionQueue,
/// Retained so `disconnect` can build a fresh prepared queue pair.
qp_builder: QueuePairBuilder,
/// Local endpoint to give to the peer; replaced when `disconnect` rebuilds the prepared QP.
endpoint: QueuePairEndpoint,
/// `Some` until `connect` consumes it in the handshake.
prepared_qp: Option<PreparedQueuePair>,
/// `Some` while connected.
qp: Option<QueuePair>,
/// Registered local buffer: `RdmaQueueMeta` header followed by the item slots. A sender
/// also uses it as staging scratch and as the landing area for remote-header reads.
memory_region: MemoryRegion<Vec<u8>>,
/// Peer's memory region, recorded by `connect`; only used by the sender.
remote_region: Option<RemoteMemoryRegion>,
/// Slot count passed to `new`.
capacity: usize,
/// Role flag: `true` permits `send`, `false` permits `recv`.
is_sender: bool,
/// Next work-request id handed out by `next_wr_id`; starts at 1.
wr_id_counter: u64,
/// Ties the queue to its item type `T` without storing one.
_phantom: std::marker::PhantomData<T>,
}
/// Byte offset of `RdmaQueueMeta::head` within the `#[repr(C)]` header.
///
/// Measured on a default-constructed probe value rather than hard-coded, so it follows
/// the struct definition and the target's `usize` width.
fn offset_of_head() -> usize {
    let probe = RdmaQueueMeta::default();
    let field_addr = std::ptr::addr_of!(probe.head) as usize;
    let struct_addr = std::ptr::addr_of!(probe) as usize;
    field_addr - struct_addr
}
/// Byte offset of `RdmaQueueMeta::tail` within the `#[repr(C)]` header.
///
/// The sender targets exactly these bytes in the receiver's region when it publishes a
/// new tail index.
fn offset_of_tail() -> usize {
    let probe = RdmaQueueMeta::default();
    let field_addr = std::ptr::addr_of!(probe.tail) as usize;
    let struct_addr = std::ptr::addr_of!(probe) as usize;
    field_addr - struct_addr
}
impl<T: GeneralShmQueueBytes> RdmaQueue<T> {
pub fn is_connected(&self) -> bool {
self.qp.is_some()
}
pub fn new(device_index: Option<usize>, capacity: usize, is_sender: bool) -> io::Result<Self> {
let devices = devices()?;
if devices.is_empty() {
return Err(io::Error::other("No RDMA devices found"));
}
let device = devices
.get(device_index.unwrap_or(0))
.ok_or_else(|| io::Error::other("Invalid device index"))?;
let context = device.open()?;
let pd = context.alloc_pd()?;
let send_cq = context.create_cq(256, 0)?;
let recv_cq = context.create_cq(256, 1)?;
let meta_size = std::mem::size_of::<RdmaQueueMeta>();
let data_offset = meta_size.next_multiple_of(T::aligned_size());
let total_size = data_offset + capacity * T::aligned_size();
let mut memory_region = pd.allocate(total_size)?;
let meta = RdmaQueueMeta {
capacity,
head: 0,
tail: 0,
data_offset,
ready: true,
};
let meta_bytes =
unsafe { std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size) };
memory_region.inner()[..meta_size].copy_from_slice(meta_bytes);
let mut qp_builder = pd.create_qp(&send_cq, &recv_cq, ibv_qp_type::IBV_QPT_RC)?;
qp_builder
.set_gid_index(0)
.set_max_send_wr(256)
.set_max_recv_wr(256)
.set_max_send_sge(1)
.set_max_recv_sge(1)
.allow_remote_rw();
let prepared_qp = qp_builder.build()?;
let endpoint = prepared_qp.endpoint()?;
Ok(Self {
_context: context,
_pd: pd,
_recv_cq: recv_cq,
send_cq,
qp_builder,
endpoint,
prepared_qp: Some(prepared_qp),
qp: None,
memory_region,
remote_region: None,
capacity,
is_sender,
wr_id_counter: 1,
_phantom: std::marker::PhantomData,
})
}
pub fn endpoint(&self) -> io::Result<QueuePairEndpoint> {
Ok(self.endpoint.clone())
}
pub fn memory_region(&self) -> RemoteMemoryRegion {
self.memory_region.remote()
}
pub fn connect(
&mut self,
remote_endpoint: QueuePairEndpoint,
remote_region: RemoteMemoryRegion,
) -> io::Result<()> {
self.remote_region = Some(remote_region);
if !self.is_connected() {
let prepared_qp = self
.prepared_qp
.take()
.ok_or_else(|| io::Error::other("No prepared QP available"))?;
let result = prepared_qp.handshake(remote_endpoint);
let qp = result.map_err(|e| io::Error::other(format!("QP handshake failed: {}", e)))?;
self.qp = Some(qp);
log::info!("🚀Connected to remote peer: {:?}", remote_endpoint);
Ok(())
} else {
Err(io::Error::other(
"Queue pair already connected or not prepared",
))
}
}
pub fn send(&mut self, item: &T) -> Result<(), RdmaQueueError> {
if !self.is_sender || !self.is_connected() {
return Err(RdmaQueueError::IoError);
}
let start_time = std::time::Instant::now();
let mut t = std::time::Instant::now();
let remote_region = self.remote_region.clone().ok_or(RdmaQueueError::IoError)?;
log::debug!("🚀 1. region clone: {:?}", t.elapsed());
t = std::time::Instant::now();
let meta = self.read_remote_meta(&remote_region)?;
log::debug!("🚀 2. read remote meta: {:?}", t.elapsed());
if meta.capacity == 0 || !meta.ready {
return Err(RdmaQueueError::IoError);
}
if (meta.tail + 1) % meta.capacity == meta.head {
return Err(RdmaQueueError::Full);
}
t = std::time::Instant::now();
log::debug!("🚀 3. start write remote data");
let item_offset = meta.data_offset + meta.tail * T::aligned_size();
self.write_remote_data(&remote_region, item_offset, item)?;
log::debug!("🚀 - 3 end write remote data: {:?}\n", t.elapsed());
t = std::time::Instant::now();
log::debug!("🚀 4. start update remote tail");
let new_tail = (meta.tail + 1) % meta.capacity;
self.update_remote_tail(&remote_region, new_tail)?;
log::debug!("🚀 - 4 end update remote tail: {:?}\n", t.elapsed());
log::debug!("🚀 Total send time: {:?}\n", start_time.elapsed());
Ok(())
}
pub fn recv(&mut self) -> Result<T, RdmaQueueError> {
if self.is_sender || !self.is_connected() {
return Err(RdmaQueueError::IoError);
}
let meta = self.read_local_meta();
if meta.head == meta.tail {
return Err(RdmaQueueError::Empty);
}
let item_offset = meta.data_offset + meta.head * T::aligned_size();
let item_data = &self.memory_region.inner()[item_offset..item_offset + T::CAPACITY];
let item = T::from_bytes(item_data);
let new_head = (meta.head + 1) % meta.capacity;
self.update_local_head(new_head)?;
Ok(item)
}
fn read_local_meta(&mut self) -> RdmaQueueMeta {
let meta_size = std::mem::size_of::<RdmaQueueMeta>();
let meta_bytes = &self.memory_region.inner()[..meta_size];
unsafe { std::ptr::read(meta_bytes.as_ptr() as *const RdmaQueueMeta) }
}
fn update_local_head(&mut self, new_head: usize) -> Result<(), RdmaQueueError> {
let head_offset = offset_of_head();
let head_bytes = new_head.to_le_bytes();
let memory = self.memory_region.inner();
memory[head_offset..head_offset + std::mem::size_of::<usize>()]
.copy_from_slice(&head_bytes);
Ok(())
}
fn write_remote_data(
&mut self,
remote_region: &RemoteMemoryRegion,
offset: usize,
item: &T,
) -> Result<(), RdmaQueueError> {
let write_offset = std::mem::size_of::<RdmaQueueMeta>().next_multiple_of(64);
let mut t = std::time::Instant::now();
item.write_to_slice(
&mut self.memory_region.inner()[write_offset..write_offset + T::CAPACITY],
);
log::debug!("🚀 3.1 write to local buffer: {:?}", t.elapsed());
t = std::time::Instant::now();
let local_slice = self
.memory_region
.slice(write_offset..write_offset + item.len());
let remote_slice = remote_region.slice(offset..offset + item.len());
log::debug!("🚀 3.2 prepare slices: {:?}", t.elapsed());
t = std::time::Instant::now();
let wr_id = self.next_wr_id();
let qp = self.qp.as_mut().ok_or(RdmaQueueError::IoError)?;
qp.post_write(&[local_slice], remote_slice, wr_id, None)
.map_err(|_| RdmaQueueError::IoError)?;
log::debug!("🚀 3.3 post write: {:?}", t.elapsed());
t = std::time::Instant::now();
self.wait_for_completion(wr_id)?;
log::debug!("🚀 3.4 wait for completion: {:?}", t.elapsed());
Ok(())
}
fn update_remote_tail(
&mut self,
remote_region: &RemoteMemoryRegion,
new_tail: usize,
) -> Result<(), RdmaQueueError> {
let tail_offset = offset_of_tail();
let tail_size = std::mem::size_of::<usize>();
let mut t = std::time::Instant::now();
let write_offset =
std::mem::size_of::<RdmaQueueMeta>().next_multiple_of(64) + T::aligned_size();
let tail_bytes = new_tail.to_le_bytes();
log::debug!("🚀 4.1 prepare tail bytes: {:?}", t.elapsed());
t = std::time::Instant::now();
self.memory_region.inner()[write_offset..write_offset + tail_size]
.copy_from_slice(&tail_bytes);
log::debug!("🚀 4.2 write to local buffer: {:?}", t.elapsed());
t = std::time::Instant::now();
let local_slice = self
.memory_region
.slice(write_offset..write_offset + tail_size);
let remote_slice = remote_region.slice(tail_offset..tail_offset + tail_size);
log::debug!("🚀 4.3 prepare slices: {:?}", t.elapsed());
t = std::time::Instant::now();
let wr_id = self.next_wr_id();
let qp = self.qp.as_mut().ok_or(RdmaQueueError::IoError)?;
qp.post_write(&[local_slice], remote_slice, wr_id, None)
.map_err(|_| RdmaQueueError::IoError)?;
log::debug!("🚀 4.4 post write: {:?}", t.elapsed());
t = std::time::Instant::now();
self.wait_for_completion(wr_id)?;
log::debug!("🚀 4.5 wait for completion: {:?}", t.elapsed());
Ok(())
}
fn read_remote_meta(
&mut self,
remote_region: &RemoteMemoryRegion,
) -> Result<RdmaQueueMeta, RdmaQueueError> {
let meta_size = std::mem::size_of::<RdmaQueueMeta>();
let local_slice = self.memory_region.slice(0..meta_size);
let remote_slice = remote_region.slice(0..meta_size);
let wr_id = self.next_wr_id();
let qp = self.qp.as_mut().ok_or(RdmaQueueError::IoError)?;
qp.post_read(&[local_slice], remote_slice, wr_id)
.map_err(|_| RdmaQueueError::IoError)?;
self.wait_for_completion(wr_id)?;
let meta_bytes = &self.memory_region.inner()[0..meta_size];
let meta = unsafe { std::ptr::read(meta_bytes.as_ptr() as *const RdmaQueueMeta) };
Ok(meta)
}
fn wait_for_completion(&mut self, expected_wr_id: u64) -> Result<(), RdmaQueueError> {
let mut completions = [Default::default(); 4];
loop {
match self.send_cq.wait(&mut completions, None) {
Ok(completed) => {
if !completed.is_empty() {
for completion in completed {
if completion.wr_id() == expected_wr_id {
return Ok(());
}
}
}
}
Err(_) => return Err(RdmaQueueError::IoError),
}
}
}
fn next_wr_id(&mut self) -> u64 {
let id = self.wr_id_counter;
self.wr_id_counter += 1;
id
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn try_send(&mut self, item: &T) -> Result<(), RdmaQueueError> {
self.send(item)
}
pub fn try_recv(&mut self) -> Result<T, RdmaQueueError> {
self.recv()
}
pub fn disconnect(&mut self) {
if let Some(qp) = self.qp.take() {
drop(qp);
}
if self.prepared_qp.is_none() {
let new_prepared_qp = self
.qp_builder
.build()
.unwrap_or_else(|e| panic!("Failed to build new prepared QP: {}", e));
self.prepared_qp = Some(new_prepared_qp);
self.endpoint = self.prepared_qp.as_ref().unwrap().endpoint().unwrap();
}
}
}
/// Releases the connected queue pair before the remaining fields are dropped.
///
/// Struct fields drop in declaration order, which would otherwise drop `_context` and
/// `_pd` ahead of `qp`. Clearing `qp` here makes it go first.
impl<T> Drop for RdmaQueue<T> {
    fn drop(&mut self) {
        // Overwriting the option drops the previous `QueuePair`, if any.
        self.qp = None;
    }
}
/// Fixed-width little-endian encoding of a `u64` (always eight bytes).
impl GeneralShmQueueBytes for u64 {
    const CAPACITY: usize = std::mem::size_of::<u64>();

    /// Writes the little-endian bytes into the first `CAPACITY` bytes of `slice`.
    fn write_to_slice(&self, slice: &mut [u8]) {
        let encoded = self.to_le_bytes();
        slice[..encoded.len()].copy_from_slice(&encoded);
    }

    /// Reads a little-endian `u64` from the first `CAPACITY` bytes of `bytes`.
    fn from_bytes(bytes: &[u8]) -> Self {
        let mut raw = [0u8; std::mem::size_of::<u64>()];
        raw.copy_from_slice(&bytes[..raw.len()]);
        Self::from_le_bytes(raw)
    }

    /// Encoded length, identical for every value.
    fn len(&self) -> usize {
        Self::CAPACITY
    }
}
/// NUL-terminated UTF-8 encoding of a `String` in a fixed 256-byte slot.
///
/// NOTE(review): text containing an embedded `'\0'` is truncated at that byte on decode.
impl GeneralShmQueueBytes for String {
    const CAPACITY: usize = 256;

    /// Zero-fills `slice`, then copies as much of the string as fits.
    ///
    /// At most `CAPACITY - 1` bytes (and at most `slice.len() - 1`) are copied, so a NUL
    /// terminator always remains. Truncation backs off to a UTF-8 character boundary, so
    /// the decoded text never ends in a partial character.
    fn write_to_slice(&self, slice: &mut [u8]) {
        slice.fill(0);
        let text = self.as_str();
        let limit = (Self::CAPACITY - 1).min(slice.len().saturating_sub(1));
        let mut len = text.len().min(limit);
        // `is_char_boundary(0)` is always true, so this terminates.
        while !text.is_char_boundary(len) {
            len -= 1;
        }
        slice[..len].copy_from_slice(&text.as_bytes()[..len]);
    }

    /// Decodes bytes up to the first NUL (or the whole slice), replacing invalid UTF-8
    /// with U+FFFD.
    fn from_bytes(bytes: &[u8]) -> Self {
        let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
        String::from_utf8_lossy(&bytes[..end]).into_owned()
    }

    /// Byte length of the string (not truncated to `CAPACITY`).
    fn len(&self) -> usize {
        self.as_str().len()
    }
}