/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
*/
package magic.vdb
import std.core.LibC
import std.collection.ArrayList
// Bindings to the native Faiss C API, compiled in only when the `faiss` build
// condition is "enable". Index handles are passed around as opaque CPointer<Unit>.
// Callers in this file treat every Int32 return value as a status code where
// 0 means success and anything else is a failure.
@When[faiss == "enable"]
foreign {
// Creates a flat L2 index for vectors of dimension `d` and writes the new
// index handle into the slot that `indexPtrPtr` points to.
func faiss_IndexFlatL2_new_with(indexPtrPtr: CPointer<CPointer<Unit>>, d:Int64): Int32
// Adds `n` vectors, read from the float buffer `x`, to the index.
// NOTE(review): presumably `x` must hold n * d floats — confirm against the Faiss headers.
func faiss_Index_add(indexPtr: CPointer<Unit>, n: Int64, x: CPointer<Float32>): Int32
// Writes the index to the file named by the C string `fname`.
func faiss_write_index_fname(idxPtr: CPointer<Unit>, fname: CString): Int32
// Reads an index from the file named by `fname` and writes its handle into `p_out`.
// `io_flags` is passed through to Faiss unchanged; this file always passes 1.
func faiss_read_index_fname(fname: CString, io_flags: Int32, p_out: CPointer<CPointer<Unit>>): Int32
// Releases a native index handle obtained from one of the functions above.
func faiss_Index_free(indexPtr: CPointer<Unit>): Unit
// Searches the index for the `k` nearest neighbours of each of the `n` query
// vectors in `x`. `distances` and `labels` are caller-allocated output buffers;
// this file sizes each of them as n * k elements.
func faiss_Index_search(indexPtr: CPointer<Unit>, n: Int64, x: CPointer<Float32>, k: Int64, distances: CPointer<Float32>, labels: CPointer<Int64>): Int32
// func faiss_Index_ntotal(indexPtr: CPointer<Unit>): Int64
}
/**
 * Vector database backed by a native Faiss flat L2 index.
 *
 * The instance owns two native allocations: a one-element slot allocated with
 * `LibC.malloc`, and the Faiss index whose handle is stored in that slot.
 * Both are released by `close()`; nothing in this class releases them
 * automatically, so callers must call `close()` when done.
 *
 * NOTE(review): vectors are marshalled through a fixed 1536-element Float32
 * buffer (see `toFaissVector`). An index created with a `dimension` other than
 * 1536 will therefore not match the buffer handed to Faiss — confirm whether
 * other dimensions are meant to be supported.
 */
@When[faiss == "enable"]
public class FaissVectorDatabase <: VectorDatabase<FaissVectorDatabase> {
    // Heap slot holding the opaque Faiss index handle.
    private let faissIndexPtr: CPointer<CPointer<Unit>>

    // Set by close(); guards against double free and use after free.
    private var closed: Bool = false

    /**
     * Wraps an already-populated handle slot. Used by `load`.
     */
    private init(faissIndexPtr: CPointer<CPointer<Unit>>) {
        this.faissIndexPtr = faissIndexPtr
    }

    /**
     * Creates an empty flat L2 index.
     *
     * @param dimension dimension passed to Faiss when creating the index.
     * @throws VectorDatabaseException if Faiss reports a non-zero status.
     */
    public init(dimension!: Int64 = 1536) {
        unsafe {
            faissIndexPtr = LibC.malloc<CPointer<Unit>>(count: 1)
            let status = faiss_IndexFlatL2_new_with(faissIndexPtr, dimension)
            if (status != 0) {
                // No index was created, so only the slot has to be released.
                LibC.free(faissIndexPtr)
                throw VectorDatabaseException("Failed to open database.")
            }
        }
    }

    /**
     * Releases the native index and the slot that holds its handle.
     * Calling it again after the first call is a no-op.
     */
    public func close(): Unit {
        if (closed) {
            return
        }
        closed = true
        unsafe {
            faiss_Index_free(faissIndexPtr.read())
            LibC.free(faissIndexPtr)
        }
    }

    /**
     * Writes the index to `filePath`.
     *
     * @throws VectorDatabaseException if the database is closed or Faiss
     *         reports a non-zero status.
     */
    public override func save(filePath: String): Unit {
        ensureOpen()
        unsafe {
            let cPath: CString = LibC.mallocCString(filePath)
            let status = faiss_write_index_fname(faissIndexPtr.read(), cPath)
            // The C string is only needed for the call itself.
            LibC.free(cPath)
            if (status != 0) {
                throw VectorDatabaseException("Failed to write faiss database.")
            }
        }
    }

    /**
     * Reads an index from `filePath` and wraps it in a new instance.
     *
     * @throws Exception if Faiss reports a non-zero status.
     */
    public static redef func load(filePath: String): FaissVectorDatabase {
        let faissIndexPtr: CPointer<CPointer<Unit>>
        unsafe {
            faissIndexPtr = LibC.malloc<CPointer<Unit>>(count: 1)
            let cPath: CString = LibC.mallocCString(filePath)
            // io_flags value 1 is forwarded to Faiss as-is.
            let ret = faiss_read_index_fname(cPath, 1, faissIndexPtr)
            LibC.free(cPath)
            if (ret != 0) {
                // Nothing was loaded; release the slot before reporting the failure.
                LibC.free(faissIndexPtr)
                throw Exception("Failed to open faiss database.")
            }
        }
        return FaissVectorDatabase(faissIndexPtr)
    }

    /**
     * Throws if `close()` has already released the native resources.
     */
    private func ensureOpen(): Unit {
        if (closed) {
            throw VectorDatabaseException("Faiss database is closed.")
        }
    }

    /**
     * Copies `source.vector` into a zero-filled 1536-element Float32 buffer.
     * Shorter inputs are zero-padded; an input longer than 1536 elements
     * indexes past the end of the buffer.
     */
    private func toFaissVector(source: Vector): VArray<Float32, $1536> {
        var v = VArray<Float32, $1536>({_: Int64 => 0.0})
        for (i in 0..source.vector.size) {
            v[i] = Float32(source.vector[i])
        }
        return v
    }

    /**
     * Adds a single vector to the index.
     *
     * @throws VectorDatabaseException if the database is closed or Faiss
     *         reports a non-zero status.
     */
    public override func addVector(vector: Vector): Unit {
        ensureOpen()
        var vec: VArray<Float32, $1536> = toFaissVector(vector)
        let status = unsafe {
            faiss_Index_add(faissIndexPtr.read(), 1, inout vec)
        }
        if (status != 0) {
            throw VectorDatabaseException("Failed to add vector to faiss database.")
        }
    }

    /**
     * Searches for up to `number` nearest neighbours of `queryVec`.
     *
     * Hits are collected in the order Faiss returns them. Collection stops at
     * the first hit whose distance is below `minDistance`, or at the first
     * negative label.
     *
     * @param queryVec the query vector.
     * @param number maximum number of neighbours requested from Faiss.
     * @param minDistance distance threshold that ends result collection.
     * @return the collected hits as (label, distance) search results.
     * @throws VectorDatabaseException if the database is closed or Faiss
     *         reports a non-zero status.
     */
    public override func search(queryVec: Vector, number!: Int64 = 5, minDistance!: Float64 = 0.6): Array<SearchResult> {
        ensureOpen()
        var queryFaissVec = toFaissVector(queryVec)
        let result = ArrayList<SearchResult>()
        let N = 1 // number of queries
        unsafe {
            // Output buffers filled by Faiss: one distance and one label per hit.
            let distancePtr = LibC.malloc<Float32>(count: N * number)
            let indexPtr = LibC.malloc<Int64>(count: N * number)
            try {
                let status = faiss_Index_search(faissIndexPtr.read(), N, inout queryFaissVec, number, distancePtr, indexPtr)
                if (status != 0) {
                    throw VectorDatabaseException("Failed to search faiss database.")
                }
                for (i in 0..number) {
                    let dist = distancePtr.read(i)
                    // NOTE(review): this stops at the first hit *closer* than
                    // minDistance; kept as-is, but confirm the intended threshold semantics.
                    if (Float64(dist) < minDistance) {
                        break
                    }
                    let index = indexPtr.read(i)
                    // Faiss fills unused result slots with label -1 when fewer
                    // than `number` neighbours exist; those are not real hits.
                    if (index < 0) {
                        break
                    }
                    result.append(SearchResult(index, dist))
                }
            } finally {
                // Release the buffers on every exit path, including the throw above.
                LibC.free(distancePtr)
                LibC.free(indexPtr)
            }
        }
        return result.toArray()
    }
}