use std::{
    ffi::c_void,
    num::NonZero,
    ptr::NonNull,
    sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};

use nix::{
    fcntl::OFlag,
    sys::{
        mman,
        stat::{self, Mode},
    },
    unistd,
};

use crate::shmq::GeneralShmQueueBytes;

/// Errors returned by the non-blocking queue operations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShmQueueError {
    /// A send found no free slot in the ring.
    Full,
    /// A receive found no pending item in the ring.
    Empty,
}

impl std::fmt::Display for ShmQueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Full => "Queue is full",
            Self::Empty => "Queue is empty",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ShmQueueError {}

/// Header stored at the very start of every queue segment and shared by all processes that map it.
///
/// `#[repr(C)]` pins field order and padding so that independently compiled binaries (or
/// different compiler versions) agree on where each field lives; with the default Rust
/// representation the compiler is free to reorder fields.
#[repr(C)]
pub struct ShmQueueMeta {
    /// The `capacity` requested by the creating handle.
    capacity: usize,
    /// Ring index of the next slot to be read.
    head: usize,
    /// Ring index of the next slot to be written.
    tail: usize,
    /// Offset of the first ring slot from the start of the segment, counted in units of
    /// `T::aligned_size()` bytes.
    data_offset: usize,
    /// Set by the creator once the other fields are initialized; openers wait for it.
    ready: bool,
}

/// Handle to a queue of `T` items living in a named POSIX shared-memory segment.
///
/// The segment begins with a [`ShmQueueMeta`] header followed by a ring of fixed-size slots of
/// `T::aligned_size()` bytes each. `meta` and `data` are views into the mapping recorded in
/// `mmap`, which is unmapped when the handle is dropped.
pub struct ShmQueue<'a, T> {
    /// Shared header at the start of the mapping.
    meta: &'a mut ShmQueueMeta,
    /// The ring of item slots inside the mapping.
    data: &'a mut [u8],
    /// Base address and byte length of the mapping, handed to `munmap` on drop.
    mmap: (NonNull<c_void>, usize),
    /// Name passed to `shm_open`; unlinked on drop when `owned` is set.
    name: String,
    /// `true` for the handle that created the segment, `false` for handles from `open`.
    owned: bool,
    /// Associates the handle with its item type without storing any `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<'a, T: GeneralShmQueueBytes> ShmQueue<'a, T> {
    pub fn new(name: &str, capacity: usize) -> Self {
        let meta_vs_data = std::mem::size_of::<ShmQueueMeta>() / T::aligned_size();
        let len = (capacity + 1 + meta_vs_data) * T::aligned_size();
        let shm = mman::shm_open(
            name,
            OFlag::O_CREAT | OFlag::O_RDWR | OFlag::O_EXCL,
            Mode::S_IRUSR | Mode::S_IWUSR,
        )
        .unwrap();

        unistd::ftruncate(&shm, len as _).unwrap();
        let mmap = unsafe {
            mman::mmap(
                None,
                NonZero::new(len).unwrap(),
                mman::ProtFlags::PROT_READ | mman::ProtFlags::PROT_WRITE,
                mman::MapFlags::MAP_SHARED,
                &shm,
                0,
            )
            .unwrap()
            .cast::<u8>()
        };

        let meta = unsafe { &mut *(mmap.as_ptr() as *mut ShmQueueMeta) };
        meta.capacity = capacity;
        meta.head = 0;
        meta.tail = 0;
        meta.data_offset = 1 + meta_vs_data;
        meta.ready = true;

        let data = unsafe {
            std::slice::from_raw_parts_mut(
                mmap.as_ptr().add(meta.data_offset * T::aligned_size()),
                capacity * T::aligned_size(),
            )
        };

        Self {
            owned: true,
            name: name.to_string(),
            mmap: (mmap.cast(), len),
            meta,
            data,
            _phantom: std::marker::PhantomData,
        }
    }

    pub fn open(name: &str) -> Option<Self> {
        let Ok(shm) = mman::shm_open(name, OFlag::O_RDWR, Mode::S_IRUSR | Mode::S_IWUSR) else {
            return None;
        };

        let len = stat::fstat(&shm).unwrap().st_size as usize;
        if len < std::mem::size_of::<ShmQueueMeta>() {
            return None;
        }

        let mmap = unsafe {
            mman::mmap(
                None,
                NonZero::new(len).unwrap(),
                mman::ProtFlags::PROT_READ | mman::ProtFlags::PROT_WRITE,
                mman::MapFlags::MAP_SHARED,
                &shm,
                0,
            )
            .unwrap()
            .cast::<u8>()
        };

        let meta = unsafe { &mut *(mmap.as_ptr() as *mut ShmQueueMeta) };
        while !unsafe { std::ptr::read_volatile(&meta.ready) } {
            std::thread::sleep(std::time::Duration::from_micros(100));
        }

        let data = unsafe {
            std::slice::from_raw_parts_mut(
                mmap.as_ptr().add(meta.data_offset * T::aligned_size()),
                meta.capacity * T::aligned_size(),
            )
        };

        Some(Self {
            owned: false,
            name: name.to_string(),
            mmap: (mmap.cast(), len),
            meta,
            data,
            _phantom: std::marker::PhantomData,
        })
    }

    pub fn capacity(&self) -> usize {
        self.meta.capacity
    }

    pub fn send(&mut self, item: &T) -> Result<(), ShmQueueError> {
        let meta = unsafe { std::ptr::read_volatile(self.meta) };
        if meta.head == (meta.tail + 1) % meta.capacity {
            return Err(ShmQueueError::Full);
        }

        item.write_to_slice(
            &mut self.data[meta.tail * T::aligned_size()..(meta.tail + 1) * T::aligned_size()],
        );

        self.meta.tail = (self.meta.tail + 1) % self.meta.capacity;
        Ok(())
    }

    pub fn recv(&mut self) -> Result<T, ShmQueueError> {
        let meta = unsafe { std::ptr::read_volatile(self.meta) };
        if meta.head == meta.tail {
            return Err(ShmQueueError::Empty);
        }

        let item = T::from_bytes(
            &self.data[meta.head * T::aligned_size()..(meta.head + 1) * T::aligned_size()],
        );
        self.meta.head = (self.meta.head + 1) % self.meta.capacity;
        Ok(item)
    }
}

impl<'a, T> Drop for ShmQueue<'a, T> {
    /// Unmaps this handle's view of the segment and, for the creating handle, unlinks its name.
    ///
    /// Failures are ignored on purpose. Panicking in a destructor aborts the process if it runs
    /// during unwinding, and a destructor has no way to recover from a failed
    /// `munmap`/`shm_unlink` anyway.
    fn drop(&mut self) {
        let (addr, len) = self.mmap;
        // SAFETY: `addr`/`len` describe the mapping this handle created; `meta` and `data` point
        // into it and are never touched again after this call.
        let _ = unsafe { mman::munmap(addr, len) };
        if self.owned {
            let _ = mman::shm_unlink(self.name.as_str());
        }
    }
}

// SAFETY: the handle's fields are plain data plus pointers into its own mapping, which is valid
// in every thread of the process. No `T` is stored in the handle; values are only produced by
// `recv` on the calling thread. Moving the handle to another thread therefore carries nothing
// thread-affine. `Sync` is intentionally not implemented.
unsafe impl<T> Send for ShmQueue<'_, T> {}

#[cfg(test)]
mod test {
    use crate::shmq::GeneralShmQueueBytes;

    use super::*;

    impl GeneralShmQueueBytes for i32 {
        const CAPACITY: usize = std::mem::size_of::<i32>();

        fn write_to_slice(&self, slice: &mut [u8]) {
            slice[..4].copy_from_slice(&self.to_le_bytes());
        }

        fn from_bytes(bytes: &[u8]) -> Self {
            // Only the first 4 bytes hold the value, in case the slot is padded.
            i32::from_le_bytes(bytes[..4].try_into().unwrap())
        }

        fn len(&self) -> usize {
            Self::CAPACITY
        }
    }

    /// Returns a segment name unique to this test and process.
    ///
    /// libtest runs tests in parallel, and `ShmQueue::new` uses `O_EXCL`, so tests must not
    /// share a name. Any stale segment left by a crashed earlier run is removed first.
    fn unique_name(test: &str) -> String {
        let name = format!("shmq_{test}_{}", std::process::id());
        let _ = mman::shm_unlink(name.as_str());
        name
    }

    #[test]
    fn test_queue() {
        let name = unique_name("queue");
        let mut sender = ShmQueue::new(&name, 10);
        assert_eq!(sender.capacity(), 10);

        assert!(sender.send(&1).is_ok());
        assert!(sender.send(&2).is_ok());
        assert_eq!(sender.recv(), Ok(1));
        assert_eq!(sender.recv(), Ok(2));
        assert_eq!(sender.recv(), Err(ShmQueueError::Empty));

        let mut receiver = ShmQueue::open(&name).unwrap();
        assert_eq!(receiver.capacity(), 10);

        assert!(receiver.send(&3).is_ok());
        assert_eq!(receiver.recv(), Ok(3));

        assert!(sender.send(&4).is_ok());
        assert_eq!(receiver.recv(), Ok(4));

        let mut counter = 0;
        while sender.send(&5).is_ok() {
            counter += 1;
        }
        assert_eq!(sender.send(&5), Err(ShmQueueError::Full));

        while let Ok(item) = receiver.recv() {
            counter -= 1;
            assert_eq!(item, 5);
        }
        assert_eq!(counter, 0);
        assert_eq!(receiver.recv(), Err(ShmQueueError::Empty));
    }

    #[test]
    fn test_raii() {
        let name = unique_name("raii");
        let owner = ShmQueue::<i32>::new(&name, 10);
        assert!(ShmQueue::<i32>::open(&name).is_some());
        drop(owner);
        assert!(ShmQueue::<i32>::open(&name).is_none());

        assert!(ShmQueue::<i32>::open(&unique_name("missing")).is_none());
    }
}