use std::{
cell::UnsafeCell,
collections::HashMap,
sync::{OnceLock, RwLock},
};
use bytes::Bytes;
use memmap2::Mmap;
use safetensors::SafeTensors;
/// In-memory cache mapping string keys to `Bytes` buffers.
///
/// The map lives in an `UnsafeCell` and is guarded manually by the separate
/// `mu` lock rather than being wrapped in the lock itself.
pub struct MemCache {
// Key -> buffer table. Must only be dereferenced while `mu` is held.
map: UnsafeCell<HashMap<String, Bytes>>,
// Guards `map`: read-locked by lookups, write-locked by `insert`/`remove`.
mu: RwLock<()>,
}
impl Default for MemCache {
fn default() -> Self {
Self::new()
}
}
impl MemCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        let map = UnsafeCell::new(HashMap::new());
        let mu = RwLock::new(());
        Self { map, mu }
    }

    /// Returns `true` when `key` currently has a buffer stored in the cache.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn contains_key(&self, key: &str) -> bool {
        let _read_guard = self.mu.read().unwrap();
        // SAFETY: the read guard is held for the whole access, so no writer
        // can hold a `&mut` to the map at the same time.
        let entries = unsafe { &*self.map.get() };
        entries.contains_key(key)
    }

    /// Returns the bytes stored under `key`, or `None` when the key is absent.
    ///
    /// The read lock is released when this method returns, but the returned
    /// slice stays borrowed for as long as `self` is. A later `remove` or an
    /// overwriting `insert` for the same key drops the `Bytes` handle held by
    /// the map, which may free the buffer the slice points at if no other
    /// handle keeps it alive. Callers must not use the slice past such a call.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_ref<'a>(&'a self, key: &str) -> Option<&'a [u8]> {
        let _read_guard = self.mu.read().unwrap();
        // SAFETY: the read guard excludes writers while the map is inspected.
        // The lifetime of the returned slice is NOT covered by the guard; see
        // the caveat in the method documentation.
        let entries: &'a HashMap<String, Bytes> = unsafe { &*self.map.get() };
        let bytes: &'a Bytes = entries.get(key)?;
        Some(bytes.as_ref())
    }

    /// Stores `value` under `key`, replacing any buffer already stored there.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn insert(&self, key: &str, value: Bytes) {
        let _write_guard = self.mu.write().unwrap();
        // SAFETY: the write guard gives this call exclusive access to the map.
        let entries = unsafe { &mut *self.map.get() };
        entries.insert(key.to_owned(), value);
    }

    /// Removes the buffer stored under `key`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn remove(&self, key: &str) {
        let _write_guard = self.mu.write().unwrap();
        // SAFETY: the write guard gives this call exclusive access to the map.
        let entries = unsafe { &mut *self.map.get() };
        entries.remove(key);
    }
}
// SAFETY: the `UnsafeCell` is what makes `MemCache` `!Sync` by default. Every
// `MemCache` method that dereferences it first takes `mu` (read for lookups,
// write for mutations), so the map itself is never read and written
// concurrently.
// NOTE(review): `get_ref` returns a slice that outlives its read guard; these
// impls are only sound if callers never use such a slice after the same key is
// removed or overwritten — confirm against call sites.
unsafe impl Send for MemCache {}
unsafe impl Sync for MemCache {}
/// Insert-only cache of memory-mapped safetensors files, keyed by name.
///
/// `get` returns references that outlive the internal lock, so two invariants
/// keep those references valid for as long as the cache is borrowed:
///
/// * every entry is boxed, so growing the hash table moves only the `Box`
///   pointers and never the `SafeTensorWithData` a reference points into;
/// * entries are never removed or replaced (`insert` keeps the first value).
pub struct SafetensorCache<'a> {
    /// Key -> entry table. Must only be dereferenced while `lk` is held.
    ///
    /// Values are boxed on purpose: storing them inline would let a table
    /// resize during `insert` relocate entries and leave references
    /// previously returned by `get` dangling.
    map: UnsafeCell<HashMap<String, Box<SafeTensorWithData<'a>>>>,
    /// Guards structural access to `map`.
    lk: RwLock<()>,
}

impl Default for SafetensorCache<'_> {
    /// Creates an empty cache; same as `SafetensorCache::new`.
    fn default() -> Self {
        Self::new()
    }
}

impl<'data> SafetensorCache<'data> {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            map: UnsafeCell::new(HashMap::new()),
            lk: RwLock::new(()),
        }
    }

    /// Stores `value` under `key` unless the key is already present.
    ///
    /// The first value inserted for a key wins; a later `value` for the same
    /// key is dropped. Existing entries are never replaced, which is required
    /// for the references handed out by `get` to stay valid.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn insert(&self, key: &str, value: SafeTensorWithData<'data>) {
        let _guard = self.lk.write().unwrap();
        // SAFETY: the write guard gives this call exclusive access to the map.
        // Growing the table only relocates `Box` pointers; the boxed entries
        // that outstanding `get` references point into do not move.
        let entries = unsafe { &mut *self.map.get() };
        if !entries.contains_key(key) {
            entries.insert(key.to_string(), Box::new(value));
        }
    }

    /// Returns the parsed safetensors view stored under `key`, if any.
    ///
    /// The view is parsed on first access through
    /// `SafeTensorWithData::safetensors`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned, or if the mapped file cannot
    /// be deserialized as safetensors.
    pub fn get(&self, key: &str) -> Option<&SafeTensors<'data>> {
        let _guard = self.lk.read().unwrap();
        // SAFETY: the read guard excludes writers while the table is probed.
        // The returned reference outlives the guard, which is sound because
        // it points into a boxed entry that is neither moved nor dropped
        // until the cache itself is dropped.
        let entries = unsafe { &*self.map.get() };
        entries.get(key).map(|entry| entry.safetensors())
    }

    /// Returns `true` when an entry is stored under `key`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn contains_key(&self, key: &str) -> bool {
        let _guard = self.lk.read().unwrap();
        // SAFETY: the read guard excludes writers for the duration of the lookup.
        unsafe { (*self.map.get()).contains_key(key) }
    }
}

// SAFETY: all access to the `UnsafeCell` goes through `lk`, and entries are
// boxed and never removed, so references returned by `get` remain valid while
// other threads insert.
// NOTE(review): this also relies on `SafeTensorWithData` being safe to share
// and send across threads — confirm for the `Mmap` and `SafeTensors` versions in use.
unsafe impl Send for SafetensorCache<'_> {}
unsafe impl Sync for SafetensorCache<'_> {}
/// A memory-mapped file paired with a lazily parsed safetensors view of it.
#[derive(Debug)]
pub struct SafeTensorWithData<'data> {
// Parsed view, filled in by the first `safetensors()` call from the bytes of
// `mmap`. Declared before `mmap`, so it is dropped before the mapping.
st: OnceLock<SafeTensors<'data>>,
// The mapping that `st` is deserialized from.
mmap: Mmap,
}
impl<'data> SafeTensorWithData<'data> {
    /// Wraps `mmap` without parsing it; parsing is deferred to the first
    /// `safetensors` call.
    pub fn new(mmap: Mmap) -> Self {
        let st = OnceLock::new();
        Self { st, mmap }
    }

    /// Returns the safetensors view of the mapped bytes, parsing them on the
    /// first call and reusing the stored result afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the mapped bytes cannot be deserialized as safetensors.
    pub fn safetensors(&'data self) -> &'data SafeTensors<'data> {
        // Runs at most once; later calls return the value already stored in `st`.
        let parse = || safetensors::SafeTensors::deserialize(&self.mmap).unwrap();
        self.st.get_or_init(parse)
    }
}
// Deliberately empty `Drop` implementation: it performs no cleanup of its own.
// NOTE(review): presumably kept so the type has an explicit destructor, which
// makes the drop checker stricter about borrows tied to `'data` — confirm
// before removing.
impl Drop for SafeTensorWithData<'_> {
fn drop(&mut self) {
// Nothing to do here; the fields are dropped automatically afterwards.
}
}