use arroy::distances::DotProduct;
use heed::{types::*, RwTxn};
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use arroy::{Database as ArroyDatabase, Reader, Writer};
use heed::types::SerdeJson;
use heed::{Database, EnvOpenOptions};
use kalosm_language_model::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
/// A set of raw embedding ids (a roaring bitmap of `u32`) used to restrict
/// which items a search may return.
pub type Candidates = roaring::RoaringBitmap;
/// Errors returned by [`VectorDB`] operations.
#[derive(Debug, thiserror::Error)]
pub enum VectorDbError {
/// An error reported by arroy. `heed` errors are also funneled into this
/// variant through the `From<heed::Error>` implementation in this file.
#[error("Arroy error: {0}")]
Arroy(#[from] arroy::Error),
/// No vector is stored under the requested embedding id.
#[error("Embedding {0:?} not found")]
EmbeddingNotFound(EmbeddingId),
}
impl From<heed::Error> for VectorDbError {
fn from(value: heed::Error) -> Self {
Self::Arroy(value.into())
}
}
/// A vector database backed by an arroy index (dot-product distance) stored
/// in a `heed` environment.
///
/// Embeddings are added with `add_embedding` / `add_embeddings`, looked up
/// with `get_embedding`, and queried with `search`.
#[doc(alias = "VectorDatabase")]
#[doc(alias = "Vector Database")]
pub struct VectorDB {
// Handle to the arroy database holding the vectors and the index.
database: ArroyDatabase<DotProduct>,
// String-keyed bookkeeping records, each a JSON-encoded list of `u32`.
// The code in this file uses the keys "free" (recycled ids) and "max"
// (next never-used id).
metadata: Database<Str, SerdeJson<Vec<u32>>>,
// The heed environment both databases live in.
env: heed::Env,
// Cached vector dimension; 0 means "not known yet" (see `get_dim`).
dim: AtomicUsize,
}
impl Default for VectorDB {
fn default() -> Self {
Self::new().unwrap()
}
}
impl VectorDB {
/// Caches the vector dimension used by this database.
///
/// # Panics
///
/// Panics if `dim` is 0, because 0 is reserved as the "unknown" marker
/// that `get_dim` checks for.
fn set_dim(&self, dim: usize) {
    assert!(dim != 0, "Dimension cannot be 0");
    self.dim.store(dim, std::sync::atomic::Ordering::Relaxed);
}
/// Returns the vector dimension of the database.
///
/// The cached value is used when it is non-zero. Otherwise the dimension is
/// read from the arroy index (index 0) inside a fresh read transaction and
/// stored in the cache.
///
/// # Errors
///
/// Returns an error if the read transaction or the arroy reader cannot be
/// opened.
fn get_dim(&self) -> Result<usize, arroy::Error> {
    let cached = self.dim.load(std::sync::atomic::Ordering::Relaxed);
    if cached != 0 {
        return Ok(cached);
    }

    // Nothing cached yet: ask the stored index for its dimension.
    let rtxn = self.env.read_txn()?;
    let reader = Reader::<DotProduct>::open(&rtxn, 0, self.database)?;
    let stored = reader.dimensions();
    self.set_dim(stored);
    Ok(stored)
}
/// Creates a vector database inside a freshly created temporary directory.
///
/// # Errors
///
/// Returns an error if the temporary directory cannot be created or if
/// [`VectorDB::new_at`] fails.
// NOTE(review): `scratch` is dropped when this function returns; per
// tempfile's documented `TempDir` behavior that removes the directory while
// the environment is still open. Confirm this is intended on all platforms.
#[tracing::instrument]
pub fn new() -> heed::Result<Self> {
    let scratch = tempfile::tempdir()?;
    let location = scratch.path();
    Self::new_at(location)
}
/// Opens (creating it if needed) a vector database stored at `path`.
///
/// The directory is created when missing, a heed environment with a 2 GiB
/// map size is opened in it, and the vector and metadata databases are
/// created in a single write transaction. The cached dimension starts out
/// unknown (0).
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the environment
/// cannot be opened, or the databases cannot be created or committed.
pub fn new_at(path: impl AsRef<std::path::Path>) -> heed::Result<Self> {
    // 2 * 1024^3 bytes, i.e. 2 GiB of address space for the environment.
    const MAP_SIZE_BYTES: usize = 2 * 1024 * 1024 * 1024;

    let location = path.as_ref();
    std::fs::create_dir_all(location)?;

    // SAFETY: NOTE(review): heed marks `open` as unsafe. Nothing here guards
    // against the environment's files being touched or opened elsewhere;
    // confirm that callers uphold heed's documented requirements.
    let env = unsafe {
        EnvOpenOptions::new()
            .map_size(MAP_SIZE_BYTES)
            .open(location)
    }?;

    // NOTE(review): both databases are created with the name `None`, so they
    // appear to refer to the same unnamed database. Verify this is intended.
    let mut wtxn = env.write_txn()?;
    let database: ArroyDatabase<DotProduct> = env.create_database(&mut wtxn, None)?;
    let metadata: Database<Str, SerdeJson<Vec<u32>>> = env.create_database(&mut wtxn, None)?;
    wtxn.commit()?;

    let dim = AtomicUsize::new(0);
    Ok(Self {
        database,
        metadata,
        env,
        dim,
    })
}
/// Allocates an embedding id inside the given write transaction.
///
/// A previously recycled id (the last entry of the "free" list) is reused
/// when one is available. Otherwise the counter stored under "max" is
/// handed out and incremented; a missing counter starts at 0.
///
/// # Errors
///
/// Returns an error if reading or writing the metadata fails.
fn take_id(&self, wtxn: &mut RwTxn) -> Result<EmbeddingId, heed::Error> {
    // Prefer reusing an id that was given back by `recycle_id`.
    let mut free_ids = self.metadata.get(wtxn, "free")?.unwrap_or_default();
    if let Some(reused) = free_ids.pop() {
        self.metadata.put(wtxn, "free", &free_ids)?;
        return Ok(EmbeddingId(reused));
    }

    // No recycled id: take the next never-used id and bump the counter.
    let next = match self.metadata.get(wtxn, "max")? {
        Some(counter) => counter[0],
        None => 0,
    };
    self.metadata.put(wtxn, "max", &vec![next + 1])?;
    Ok(EmbeddingId(next))
}
/// Appends `id` to the "free" list so that `take_id` can hand it out again.
///
/// # Errors
///
/// Returns an error if reading or writing the metadata fails.
fn recycle_id(&self, id: EmbeddingId, wtxn: &mut RwTxn) -> Result<(), heed::Error> {
    let mut free_ids = match self.metadata.get(wtxn, "free")? {
        Some(existing) => existing,
        None => Vec::new(),
    };
    free_ids.push(id.0);
    self.metadata.put(wtxn, "free", &free_ids)
}
pub fn raw(&self) -> (&ArroyDatabase<DotProduct>, &heed::Env) {
(&self.database, &self.env)
}
/// Removes every vector from the index and resets id allocation.
///
/// In one write transaction the arroy index 0 is cleared, the "max" counter
/// is reset to 0 and the "free" list is emptied.
///
/// # Errors
///
/// Returns an error if the dimension cannot be determined (see `get_dim`)
/// or if any database operation fails.
pub async fn clear(&self) -> Result<(), arroy::Error> {
    let mut wtxn = self.env.write_txn()?;
    let dims = self.get_dim()?;

    Writer::<DotProduct>::new(self.database, 0, dims).clear(&mut wtxn)?;

    // Restart id allocation from scratch.
    let no_free_ids: Vec<u32> = Vec::new();
    self.metadata.put(&mut wtxn, "max", &vec![0])?;
    self.metadata.put(&mut wtxn, "free", &no_free_ids)?;

    wtxn.commit()?;
    Ok(())
}
/// Rebuilds the arroy index through `writer` inside the given write
/// transaction, using a random number generator seeded from entropy.
///
/// The transaction is not committed here; that is left to the caller.
///
/// # Errors
///
/// Returns an error if arroy fails to build the index.
pub fn rebuild(
    &self,
    writer: &mut Writer<DotProduct>,
    wtxn: &mut RwTxn,
) -> Result<(), arroy::Error> {
    let mut entropy_rng = StdRng::from_entropy();
    let builder = writer.builder(&mut entropy_rng);
    builder.build(wtxn)?;
    Ok(())
}
/// Removes the embedding stored under `embedding_id` and makes its id
/// available for reuse.
///
/// If no item is stored under `embedding_id`, the call is a no-op and
/// returns `Ok(())`. In that case the id is NOT recycled. Recycling an id
/// that was never removed would put it on the free list more than once, so
/// later inserts would be given the same id and overwrite each other.
///
/// # Errors
///
/// Returns an error if the dimension cannot be determined (see `get_dim`)
/// or if deleting, recycling, rebuilding or committing fails.
pub fn remove_embedding(&self, embedding_id: EmbeddingId) -> Result<(), arroy::Error> {
    let dims = self.get_dim()?;
    let mut wtxn = self.env.write_txn()?;
    let mut writer = Writer::<DotProduct>::new(self.database, 0, dims);

    // `del_item` reports whether an item was actually stored under this id.
    let existed = writer.del_item(&mut wtxn, embedding_id.0)?;
    if !existed {
        // Nothing changed; dropping the transaction aborts it.
        return Ok(());
    }

    self.recycle_id(embedding_id, &mut wtxn)?;
    self.rebuild(&mut writer, &mut wtxn)?;
    wtxn.commit()?;
    Ok(())
}
/// Stores a single embedding and returns the id it was assigned.
///
/// The embedding's length becomes the cached dimension, and the index is
/// rebuilt before the transaction is committed.
///
/// # Errors
///
/// Returns an error if id allocation, insertion, the rebuild or the commit
/// fails.
///
/// # Panics
///
/// Panics if the embedding is empty (see `set_dim`).
pub fn add_embedding(&self, embedding: Embedding) -> Result<EmbeddingId, VectorDbError> {
    let vector = embedding.vector();
    let dims = vector.len();
    self.set_dim(dims);

    let mut wtxn = self.env.write_txn()?;
    let mut writer = Writer::<DotProduct>::new(self.database, 0, dims);

    let assigned = self.take_id(&mut wtxn)?;
    writer.add_item(&mut wtxn, assigned.0, vector)?;

    self.rebuild(&mut writer, &mut wtxn)?;
    wtxn.commit()?;
    Ok(assigned)
}
/// Stores a batch of embeddings in one transaction and returns their ids in
/// input order.
///
/// An empty input returns an empty list without touching the database.
/// The length of the first embedding is cached as the dimension and used to
/// configure the writer. The index is rebuilt once, after all items were
/// added.
///
/// # Errors
///
/// Returns an error if id allocation, any insertion, the rebuild or the
/// commit fails.
///
/// # Panics
///
/// Panics if the first embedding is empty (see `set_dim`).
pub fn add_embeddings(
    &self,
    embedding: impl IntoIterator<Item = Embedding>,
) -> Result<Vec<EmbeddingId>, VectorDbError> {
    let mut vectors = embedding
        .into_iter()
        .map(|item| item.vector().to_vec().into_boxed_slice());

    // The first vector decides the dimension; no vectors means no work.
    let first = match vectors.next() {
        Some(first) => first,
        None => return Ok(Vec::new()),
    };
    let dims = first.len();
    self.set_dim(dims);

    let mut wtxn = self.env.write_txn()?;
    let mut writer = Writer::<DotProduct>::new(self.database, 0, dims);
    let mut ids: Vec<_> = Vec::with_capacity(vectors.size_hint().0 + 1);

    for vector in std::iter::once(first).chain(vectors) {
        let id = self.take_id(&mut wtxn)?;
        writer.add_item(&mut wtxn, id.0, &vector)?;
        ids.push(id);
    }

    self.rebuild(&mut writer, &mut wtxn)?;
    wtxn.commit()?;
    Ok(ids)
}
/// Fetches the embedding stored under `embedding_id`.
///
/// # Errors
///
/// Returns [`VectorDbError::EmbeddingNotFound`] when no vector is stored
/// under that id, or an `Arroy` error if the read transaction, the reader
/// or the lookup fails.
pub fn get_embedding(&self, embedding_id: EmbeddingId) -> Result<Embedding, VectorDbError> {
    let rtxn = self.env.read_txn()?;
    let reader = Reader::<DotProduct>::open(&rtxn, 0, self.database)?;
    match reader.item_vector(&rtxn, embedding_id.0)? {
        Some(vector) => Ok(Embedding::from(vector)),
        None => Err(VectorDbError::EmbeddingNotFound(embedding_id)),
    }
}
/// Starts a nearest-neighbour search for `embedding`.
///
/// The returned builder has no result limit and no filter set; configure it
/// with `with_results` / `with_filter` and execute it with `run`.
pub fn search<'a>(&'a self, embedding: &'a Embedding) -> VectorDBSearchBuilder<'a> {
    let (results, filter) = (None, None);
    VectorDBSearchBuilder {
        db: self,
        embedding,
        results,
        filter,
    }
}
}
/// Conversion of a value into the [`Candidates`] set used to filter a search.
///
/// The type parameter `M` is a marker that keeps the implementations in this
/// file (for `Candidates`, for iterators of ids and for closures) distinct.
pub trait IntoVectorDbSearchFilter<M> {
/// Converts `self` into the set of candidate ids, with access to `db` for
/// implementations that need to inspect the stored embeddings.
fn into_vector_db_search_filter(self, db: &VectorDB) -> Candidates;
}
/// A ready-made [`Candidates`] set is used as the filter unchanged.
impl IntoVectorDbSearchFilter<()> for Candidates {
    fn into_vector_db_search_filter(self, _db: &VectorDB) -> Candidates {
        self
    }
}
/// Marker type selecting the [`IntoVectorDbSearchFilter`] implementation for
/// iterators of [`EmbeddingId`].
pub struct IteratorMarker;
/// Any iterable of [`EmbeddingId`] can be used as a filter: the raw ids are
/// gathered into a [`Candidates`] set. The database is not consulted.
impl<I> IntoVectorDbSearchFilter<IteratorMarker> for I
where
    I: IntoIterator<Item = EmbeddingId>,
{
    fn into_vector_db_search_filter(self, _db: &VectorDB) -> Candidates {
        self.into_iter().map(|id| id.0).collect()
    }
}
/// Marker type selecting the [`IntoVectorDbSearchFilter`] implementation for
/// closures that accept or reject an [`Embedding`].
pub struct ClosureMarker;
impl<I> IntoVectorDbSearchFilter<ClosureMarker> for I
where
I: FnMut(Embedding) -> bool,
{
fn into_vector_db_search_filter(mut self, db: &VectorDB) -> Candidates {
let mut candidates = Candidates::new();
let rtxn = match db.env.read_txn() {
Ok(rtxn) => rtxn,
Err(err) => {
tracing::error!("Error opening read transaction: {:?}", err);
return candidates;
}
};
let reader = match Reader::<DotProduct>::open(&rtxn, 0, db.database) {
Ok(reader) => reader,
Err(err) => {
tracing::error!("Error opening reader: {:?}", err);
return candidates;
}
};
for (key, tensor) in reader.iter(&rtxn).ok().into_iter().flatten().flatten() {
let embedding = Embedding::from(tensor);
if self(embedding) {
candidates.insert(key);
}
}
candidates
}
}
/// Builder for a nearest-neighbour search, created by `VectorDB::search`.
pub struct VectorDBSearchBuilder<'a> {
// The database being searched.
db: &'a VectorDB,
// The query embedding.
embedding: &'a Embedding,
// Requested number of results; `run` falls back to 10 when unset.
results: Option<usize>,
// Optional set of ids the search is restricted to.
filter: Option<Candidates>,
}
impl VectorDBSearchBuilder<'_> {
    /// Sets the number of results the search should return.
    pub fn with_results(self, results: usize) -> Self {
        Self {
            results: Some(results),
            ..self
        }
    }

    /// Restricts the search to the candidates produced by `filter`.
    ///
    /// The filter is converted into a [`Candidates`] set immediately, using
    /// the database this builder was created from.
    pub fn with_filter<Marker>(
        self,
        filter: impl IntoVectorDbSearchFilter<Marker> + Send + Sync + 'static,
    ) -> Self {
        let candidates = filter.into_vector_db_search_filter(self.db);
        Self {
            filter: Some(candidates),
            ..self
        }
    }

    /// Executes the search and returns the matches reported by arroy, each
    /// paired with its distance.
    ///
    /// At most 10 results are requested unless `with_results` was called.
    ///
    /// # Errors
    ///
    /// Returns an error if the read transaction or the reader cannot be
    /// opened, or if the query itself fails.
    pub fn run(self) -> Result<Vec<VectorDBSearchResult>, VectorDbError> {
        let Self {
            db,
            embedding,
            results,
            filter,
        } = self;

        let rtxn = db.env.read_txn()?;
        let reader = Reader::<DotProduct>::open(&rtxn, 0, db.database)?;

        let limit = results.unwrap_or(10);
        let mut query = reader.nns(limit);
        if let Some(candidates) = &filter {
            query.candidates(candidates);
        }

        let neighbours = query.by_vector(&rtxn, embedding.vector())?;
        let hits = neighbours
            .into_iter()
            .map(|(id, distance)| VectorDBSearchResult {
                distance,
                value: EmbeddingId(id),
            })
            .collect();
        Ok(hits)
    }
}
/// A single match produced by running a [`VectorDBSearchBuilder`].
#[derive(Debug, Clone, PartialEq)]
pub struct VectorDBSearchResult {
/// The distance arroy reported between the query and this item.
pub distance: f32,
/// The id of the matching embedding.
pub value: EmbeddingId,
}
/// Identifier of an embedding stored in a [`VectorDB`]; wraps the raw `u32`
/// item id used with arroy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(pub u32);
#[tokio::test]
async fn test_vector_db_get_closest() {
let db: VectorDB = VectorDB::new().unwrap();
let first_vector = Embedding::from([1.0, 2.0, 3.0]);
let second_embedding = Embedding::from([-1.0, 2.0, 3.0]);
let id1 = db.add_embedding(first_vector.clone()).unwrap();
let id2 = db.add_embedding(second_embedding.clone()).unwrap();
assert_eq!(
db.search(&first_vector)
.with_results(1)
.run()
.unwrap()
.iter()
.map(|r| r.value)
.collect::<Vec<_>>(),
vec![id1]
);
assert_eq!(
db.search(&second_embedding)
.with_results(1)
.run()
.unwrap()
.iter()
.map(|r| r.value)
.collect::<Vec<_>>(),
vec![id2]
);
let third_embedding = Embedding::from([1.0, 0.0, 0.0]);
assert_eq!(
db.search(&third_embedding)
.with_results(1)
.run()
.unwrap()
.iter()
.map(|r| r.value)
.collect::<Vec<_>>(),
vec![id1]
);
assert_eq!(
db.search(&third_embedding)
.with_filter(|vector: Embedding| vector.vector()[0] < 0.0)
.run()
.unwrap()
.iter()
.map(|r| r.value)
.collect::<Vec<_>>(),
vec![id2]
);
}