use std::{fmt::Debug, sync::Arc};
use crate::{Device, QMatrix};
pub use fusor_gguf::{GgufMetadata, GgufReadError, GgufValue};
/// Object-safe combination of [`std::io::Read`] and [`std::io::Seek`], so a
/// reader supporting both can be stored behind a single `dyn` reference.
trait ReadAndSeek: std::io::Read + std::io::Seek {}

/// Blanket implementation: every type (sized or not) that is both readable
/// and seekable is automatically a `ReadAndSeek`.
impl<T> ReadAndSeek for T where T: std::io::Read + std::io::Seek + ?Sized {}
/// Computes the byte length of a tensor holding `num_elements` elements of
/// GGML type `ty`.
///
/// The element count is integer-divided by `ty.block_size()`, so only whole
/// blocks are counted; a trailing partial block contributes nothing. Each
/// whole block is charged `ty.block_allocation_size()` bytes.
fn tensor_byte_size(ty: fusor_gguf::GgmlType, num_elements: usize) -> usize {
    let whole_blocks = num_elements / ty.block_size();
    let bytes_per_block = ty.block_allocation_size();
    whole_blocks * bytes_per_block
}
/// Loads tensors by name from a GGUF source, optionally under a dotted name
/// prefix.
pub struct VarBuilder<'a> {
    /// Borrowed reader that is seeked and read when tensor data is requested.
    reader: &'a mut dyn ReadAndSeek,
    /// GGUF metadata, reference-counted so builders derived from this one
    /// share it instead of copying it.
    metadata: Arc<GgufMetadata>,
    /// Prefix segments; they are joined with `.` and prepended to the names
    /// passed to the lookup methods.
    path: Vec<String>,
}
/// Debug output shows the path prefix and the metadata; the reader is omitted
/// from the output.
impl Debug for VarBuilder<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { path, metadata, .. } = self;
        let mut out = f.debug_struct("VarBuilder");
        out.field("path", path);
        out.field("metadata", metadata);
        out.finish()
    }
}
impl<'a> VarBuilder<'a> {
pub fn from_gguf<R: std::io::Read + std::io::Seek>(
reader: &'a mut R,
) -> Result<Self, GgufReadError> {
let metadata = GgufMetadata::read(&mut *reader)?.into();
let path = Default::default();
Ok(Self {
reader,
metadata,
path,
})
}
pub fn pp<'b, S: ToString>(&'b mut self, s: S) -> VarBuilder<'b> {
let mut new_path = self.path.clone();
new_path.push(s.to_string());
VarBuilder {
reader: &mut *self.reader,
metadata: self.metadata.clone(),
path: new_path,
}
}
fn format_path(&self, name: &str) -> String {
let mut full_path = self.path.join(".");
if !full_path.is_empty() {
full_path.push('.');
}
full_path.push_str(name);
full_path
}
pub fn get(&mut self, key: &str, device: &Device) -> crate::Result<QMatrix> {
let full_path = self.format_path(key);
let q_matrix_metadata = self.metadata.tensor_infos.get(&*full_path).ok_or_else(|| {
crate::Error::VarBuilder(format!("Key '{}' not found in GGUF metadata", full_path))
})?;
let tensor_info = q_matrix_metadata;
let offset = self.metadata.tensor_data_offset + tensor_info.offset;
self.reader
.seek(std::io::SeekFrom::Start(offset))
.map_err(|e| {
crate::Error::VarBuilder(format!("Failed to seek to tensor data: {}", e))
})?;
let ggml_type = tensor_info.ty;
let shape: Box<[usize]> = tensor_info
.shape
.iter()
.map(|&d| d as usize)
.collect::<Vec<_>>()
.into_boxed_slice();
let num_elements: usize = shape.iter().product();
let byte_size = tensor_byte_size(ggml_type, num_elements);
let mut bytes = vec![0u8; byte_size];
self.reader
.read_exact(&mut bytes)
.map_err(|e| crate::Error::VarBuilder(format!("Failed to read tensor data: {}", e)))?;
QMatrix::from_raw_bytes(device, shape, &bytes, ggml_type)
.map_err(|e| crate::Error::VarBuilder(format!("Failed to create QMatrix: {}", e)))
}
pub fn contains_key(&self, key: &str) -> bool {
let full_path = self.format_path(key);
self.metadata.tensor_infos.contains_key(&*full_path)
}
pub fn list_all_keys(&self) -> Vec<String> {
self.metadata
.tensor_infos
.keys()
.map(|k| k.to_string())
.collect()
}
pub fn metadata(&self) -> &GgufMetadata {
&self.metadata
}
pub fn get_metadata(&self, key: &str) -> Option<&GgufValue> {
self.metadata.get_value(key)
}
pub fn architecture(&self) -> Option<String> {
self.metadata.architecture()
}
}
/// Looks up metadata values and tensors across several GGUF sources.
///
/// Lookups scan the entries in order and use the first one that contains the
/// requested name.
pub struct ShardedVarBuilder<R: std::io::Read + std::io::Seek> {
    /// One entry per source: its metadata paired with the reader that tensor
    /// data for that metadata is read from.
    contents: Vec<(GgufMetadata, R)>,
}
impl<R: std::io::Read + std::io::Seek> ShardedVarBuilder<R> {
pub fn new(contents: Vec<(GgufMetadata, R)>) -> Self {
Self { contents }
}
pub fn get(&self, name: &str) -> crate::Result<&GgufValue> {
for (content, _) in &self.contents {
if let Some(value) = content.get_value(name) {
return Ok(value);
}
}
Err(crate::Error::VarBuilder(format!(
"Key '{}' not found in GGUF metadata",
name
)))
}
pub fn tensor(&mut self, name: &str, device: &Device) -> crate::Result<QMatrix> {
for (content, r) in &mut self.contents {
if let Some(tensor_info) = content.tensor_infos.get(name) {
let offset = content.tensor_data_offset + tensor_info.offset;
r.seek(std::io::SeekFrom::Start(offset))
.map_err(|e| crate::Error::VarBuilder(format!("Failed to seek: {}", e)))?;
let ggml_type = tensor_info.ty;
let shape: Box<[usize]> = tensor_info
.shape
.iter()
.map(|&d| d as usize)
.collect::<Vec<_>>()
.into_boxed_slice();
let num_elements: usize = shape.iter().product();
let byte_size = tensor_byte_size(ggml_type, num_elements);
let mut bytes = vec![0u8; byte_size];
r.read_exact(&mut bytes)
.map_err(|e| crate::Error::VarBuilder(format!("Failed to read: {}", e)))?;
return QMatrix::from_raw_bytes(device, shape, &bytes, ggml_type).map_err(|e| {
crate::Error::VarBuilder(format!("Failed to create QMatrix: {}", e))
});
}
}
Err(crate::Error::VarBuilder(format!(
"Key '{}' not found in GGUF metadata",
name
)))
}
}