use kalosm_model_types::{FileLoadingProgress, FileSource};
use std::path::PathBuf;
#[cfg(feature = "tokio")]
use tokio::fs::{File, OpenOptions};
#[cfg(feature = "tokio")]
use tokio::io::AsyncWriteExt;
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CacheError {
#[cfg(feature = "tokio")]
#[error("Hugging Face API error: {0}")]
HuggingFaceApi(#[from] hf_hub::api::tokio::ApiError),
#[cfg(feature = "tokio")]
#[error("Unable to get file metadata for {0}: {1}")]
UnableToGetFileMetadata(PathBuf, #[source] tokio::io::Error),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("HTTP error: {0}")]
Http(#[from] reqwest::Error),
#[error("Unexpected status code: {0}")]
UnexpectedStatusCode(reqwest::StatusCode),
#[cfg(target_arch = "wasm32")]
#[error("OPFS not available: {0}")]
OpfsNotAvailable(String),
#[cfg(target_arch = "wasm32")]
#[error("OPFS operation failed: {0}")]
OpfsError(String),
}
#[derive(Debug, Clone)]
pub struct Cache {
location: PathBuf,
huggingface_token: Option<String>,
}
impl Cache {
pub fn new(location: PathBuf) -> Self {
Self {
location,
huggingface_token: None,
}
}
pub fn with_huggingface_token(mut self, token: Option<String>) -> Self {
self.huggingface_token = token;
self
}
pub fn exists(&self, source: &FileSource) -> bool {
match source {
FileSource::HuggingFace {
model_id,
revision,
file,
..
} => {
let path = self.location.join(model_id).join(revision);
let complete_download = path.join(file);
complete_download.exists()
}
FileSource::Local(path) => path.exists(),
}
}
pub async fn get_bytes(
&self,
source: &FileSource,
#[allow(unused_mut)] mut progress: impl FnMut(FileLoadingProgress),
) -> Result<Vec<u8>, CacheError> {
#[cfg(feature = "tokio")]
{
let path = self.get(source, progress).await;
tokio::fs::read(path?).await.map_err(CacheError::from)
}
#[cfg(not(feature = "tokio"))]
{
#[cfg(target_arch = "wasm32")]
{
use crate::opfs::is_opfs_available;
if is_opfs_available().await {
match self.get_bytes_opfs(source, &mut progress).await {
Ok(bytes) => return Ok(bytes),
Err(e) => {
tracing::warn!("OPFS caching failed, falling back to in-memory: {}", e);
}
}
}
self.get_bytes_memory(source, progress).await
}
#[cfg(not(target_arch = "wasm32"))]
{
self.get_bytes_memory(source, progress).await
}
}
}
#[cfg(feature = "tokio")]
pub async fn get(
&self,
source: &FileSource,
progress: impl FnMut(FileLoadingProgress),
) -> Result<PathBuf, CacheError> {
use hf_hub::{Repo, RepoType};
match source {
FileSource::HuggingFace {
model_id,
revision,
file,
} => {
let token = self.huggingface_token.clone().or_else(huggingface_token);
let repo = Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
);
let api = hf_hub::api::tokio::Api::new()?.repo(repo);
let url = api.url(file);
let client = reqwest::Client::new();
tracing::trace!("Fetching metadata for {file} from {url}");
let response = client
.head(&url)
.with_authorization_header(token.clone())
.send()
.await;
let path = self.location.join(model_id).join(revision);
let complete_download = path.join(file);
if is_file_current(&complete_download, &response).await {
return Ok(complete_download);
}
let lock_path = path.join(format!("{file}.lock"));
tokio::fs::create_dir_all(&path).await?;
let lock_path_clone = lock_path.clone();
let _lock_file = tokio::task::spawn_blocking(move || {
let file = std::fs::File::create(&lock_path_clone)?;
file.lock()?;
Ok::<_, std::io::Error>(file)
})
.await
.map_err(|e| CacheError::Io(std::io::Error::other(e)))??;
if is_file_current(&complete_download, &response).await {
let _ = tokio::fs::remove_file(&lock_path).await;
return Ok(complete_download);
}
let incomplete_download = path.join(format!("{file}.partial"));
tracing::trace!("Downloading into {:?}", incomplete_download);
let download_result = download_into(
url,
&incomplete_download,
response?,
client,
token,
progress,
)
.await;
if let Err(e) = download_result {
let _ = tokio::fs::remove_file(&lock_path).await;
return Err(e);
}
let rename_result =
tokio::fs::rename(&incomplete_download, &complete_download).await;
let _ = tokio::fs::remove_file(&lock_path).await;
rename_result?;
Ok(complete_download)
}
FileSource::Local(path) => Ok(path.clone()),
}
}
#[cfg(all(not(feature = "tokio"), target_arch = "wasm32"))]
async fn get_bytes_opfs(
&self,
source: &FileSource,
progress: &mut impl FnMut(FileLoadingProgress),
) -> Result<Vec<u8>, CacheError> {
use crate::opfs::{
close_writable_stream, seek_writable_stream, write_chunk_to_stream, OpfsCache,
};
use futures_util::StreamExt;
use reqwest::StatusCode;
match source {
FileSource::HuggingFace {
model_id,
revision,
file,
} => {
let opfs = OpfsCache::new().await?;
let cache_dir = opfs
.get_directory(&["kalosm", "cache", model_id, revision])
.await?;
let safe_file = crate::opfs::sanitize_name(file);
let token = self.huggingface_token.clone().or_else(huggingface_token);
let url = format!("https://huggingface.co/{model_id}/resolve/{revision}/{file}");
let client = reqwest::Client::new();
let head_response = client
.head(&url)
.with_authorization_header(token.clone())
.send()
.await?;
let expected_size = head_response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let local_size = opfs
.get_file_size(&cache_dir, &safe_file)
.await
.unwrap_or(0);
if let Some(expected) = expected_size {
if local_size == expected {
let bytes = opfs.read_file(&cache_dir, &safe_file).await?;
progress(FileLoadingProgress {
progress: local_size,
cached_size: local_size,
size: local_size,
start_time: None,
});
return Ok(bytes);
} else if local_size > expected {
let _ = opfs.delete_file(&cache_dir, &safe_file).await;
}
}
let start_offset = if expected_size.map_or(false, |e| local_size > e) {
0
} else {
local_size
};
let final_url = head_response.url().clone();
let mut request = client.get(final_url.clone());
if start_offset > 0 {
request =
request.header(reqwest::header::RANGE, format!("bytes={}-", start_offset));
}
let mut response = request.send().await?;
let mut status = response.status();
if status == StatusCode::RANGE_NOT_SATISFIABLE {
if local_size > 0 {
let bytes = opfs.read_file(&cache_dir, &safe_file).await?;
progress(FileLoadingProgress {
progress: local_size,
cached_size: local_size,
size: local_size,
start_time: None,
});
return Ok(bytes);
}
response = client
.get(final_url)
.with_authorization_header(token.clone())
.send()
.await?;
status = response.status();
}
let (total_size, resuming) = if status == StatusCode::PARTIAL_CONTENT {
let remaining = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
(remaining.map(|r| r + start_offset).or(expected_size), true)
} else if status == StatusCode::OK {
let total = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
(total.or(expected_size), false)
} else {
return Err(CacheError::UnexpectedStatusCode(status));
};
let actual_start = if resuming { start_offset } else { 0 };
if let Some(size) = total_size {
progress(FileLoadingProgress {
progress: actual_start,
cached_size: actual_start,
size,
start_time: None,
});
if actual_start == size {
return opfs.read_file(&cache_dir, &safe_file).await;
}
}
let mut all_bytes = if resuming && actual_start > 0 {
match opfs.read_file(&cache_dir, &safe_file).await {
Ok(existing) => existing,
Err(e) => {
tracing::warn!(
"[OPFS] Can't read existing file for resume: {}, deleting",
e
);
let _ = opfs.delete_file(&cache_dir, &safe_file).await;
return Err(CacheError::OpfsError(format!(
"Failed to read partial download for resume: {}. File deleted, please retry.",
e
)));
}
}
} else {
Vec::new()
};
let mut writable = opfs
.create_writable(&cache_dir, &safe_file, resuming)
.await?;
if resuming && actual_start > 0 {
seek_writable_stream(&writable, actual_start).await?;
}
let mut current_progress = actual_start;
let mut stream = response.bytes_stream();
const FLUSH_INTERVAL: u64 = 100 * 1024 * 1024;
let mut bytes_since_flush: u64 = 0;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result?;
write_chunk_to_stream(&writable, &chunk).await?;
all_bytes.extend_from_slice(&chunk);
current_progress += chunk.len() as u64;
bytes_since_flush += chunk.len() as u64;
if let Some(size) = total_size {
progress(FileLoadingProgress {
progress: current_progress,
cached_size: actual_start,
size,
start_time: None,
});
}
if bytes_since_flush >= FLUSH_INTERVAL {
close_writable_stream(&writable).await?;
writable = opfs.create_writable(&cache_dir, &safe_file, true).await?;
seek_writable_stream(&writable, current_progress).await?;
bytes_since_flush = 0;
}
}
close_writable_stream(&writable).await?;
Ok(all_bytes)
}
FileSource::Local(_) => Err(CacheError::Io(std::io::Error::other(
"Local file access not supported on WASM",
))),
}
}
#[cfg(not(feature = "tokio"))]
async fn get_bytes_memory(
&self,
source: &FileSource,
mut progress: impl FnMut(FileLoadingProgress),
) -> Result<Vec<u8>, CacheError> {
use futures_util::StreamExt;
use reqwest::StatusCode;
use std::str::FromStr;
match source {
FileSource::HuggingFace {
model_id,
revision,
file,
} => {
let token = self.huggingface_token.clone().or_else(huggingface_token);
let url = format!("https://huggingface.co/{model_id}/resolve/{revision}/{file}");
let client = reqwest::Client::new();
let head = client
.head(&url)
.with_authorization_header(token.clone())
.send()
.await
.map_err(CacheError::from)?;
let length = head
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|length| length.to_str().ok().and_then(|s| u64::from_str(s).ok()));
if let Some(length) = length {
progress(FileLoadingProgress {
progress: 0,
cached_size: 0,
size: length,
start_time: None,
});
}
let request = client.get(url).with_authorization_header(token);
let response = request.send().await?;
let status = response.status();
if !(status == StatusCode::OK || status == StatusCode::PARTIAL_CONTENT) {
return Err(CacheError::UnexpectedStatusCode(status));
}
let mut current_progress = 0;
let mut bytes = Vec::new();
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
bytes.extend_from_slice(&chunk);
tracing::trace!("wrote chunk of size {}", chunk.len());
current_progress += chunk.len() as u64;
if let Some(length) = length {
progress(FileLoadingProgress {
progress: current_progress,
cached_size: 0,
size: length,
start_time: None,
});
}
}
Ok(bytes)
}
_ => Err(CacheError::Io(std::io::Error::other(
"Local file access not supported without the 'tokio' feature",
))),
}
}
#[cfg(all(not(feature = "tokio"), target_arch = "wasm32"))]
pub async fn exists_async(&self, source: &FileSource) -> bool {
use crate::opfs::{is_opfs_available, sanitize_name, OpfsCache};
match source {
FileSource::HuggingFace {
model_id,
revision,
file,
..
} => {
if !is_opfs_available().await {
return false;
}
let Ok(opfs) = OpfsCache::new().await else {
return false;
};
let Ok(cache_dir) = opfs
.get_directory(&["kalosm", "cache", model_id, revision])
.await
else {
return false;
};
let safe_file = sanitize_name(file);
opfs.get_file_size(&cache_dir, &safe_file)
.await
.map_or(false, |size| size > 0)
}
FileSource::Local(_) => false,
}
}
}
#[allow(clippy::derivable_impls)]
impl Default for Cache {
fn default() -> Self {
Self {
location: {
#[cfg(feature = "tokio")]
{
dirs::data_dir()
.or_else(dirs::cache_dir)
.or_else(|| std::env::var("HOME").ok().map(std::path::PathBuf::from))
.or_else(|| std::env::current_dir().ok())
.unwrap_or_else(|| std::path::PathBuf::from("/tmp"))
.join("kalosm")
.join("cache")
}
#[cfg(not(feature = "tokio"))]
{
Default::default()
}
},
huggingface_token: None,
}
}
}
#[cfg(feature = "tokio")]
async fn is_file_current(
path: &std::path::Path,
response: &Result<reqwest::Response, reqwest::Error>,
) -> bool {
if !path.exists() {
return false;
}
let Ok(metadata) = tokio::fs::metadata(path).await else {
return false;
};
let Ok(file_last_modified) = metadata.modified() else {
return false;
};
if let Some(last_updated) = response
.as_ref()
.ok()
.and_then(|response| response.headers().get(reqwest::header::LAST_MODIFIED))
.and_then(|last_updated| last_updated.to_str().ok())
.and_then(|s| httpdate::parse_http_date(s).ok())
{
last_updated <= file_last_modified
} else {
true
}
}
#[cfg(feature = "tokio")]
async fn download_into<U: reqwest::IntoUrl>(
url: U,
file: &PathBuf,
head: reqwest::Response,
client: reqwest::Client,
token: Option<String>,
mut progress: impl FnMut(FileLoadingProgress),
) -> Result<(), CacheError> {
use reqwest::header::{HeaderValue, CONTENT_LENGTH, RANGE};
use reqwest::StatusCode;
use std::str::FromStr;
let length = head
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| u64::from_str(s).ok());
let (start, mut output_file) = if let Ok(metadata) = tokio::fs::metadata(file).await {
let start = metadata.len();
let output_file = OpenOptions::new().append(true).open(file).await?;
(start, output_file)
} else {
if let Some(parent) = file.parent() {
tokio::fs::create_dir_all(parent).await?;
}
(0, File::create(file).await?)
};
if let Some(length) = length {
progress(FileLoadingProgress {
progress: start,
cached_size: start,
size: length,
start_time: Some(std::time::Instant::now()),
});
}
if Some(start) == length {
tracing::trace!("File {} already downloaded", file.display());
progress(FileLoadingProgress {
progress: start,
cached_size: start,
size: length.unwrap_or(0),
start_time: Some(std::time::Instant::now()),
});
return Ok(());
}
let range = length
.and_then(|length| HeaderValue::from_str(&format!("bytes={}-{}", start, length - 1)).ok());
tracing::trace!("Fetching range {:?}", range);
let mut request = client.get(url).with_authorization_header(token);
if let Some(range) = range {
request = request.header(RANGE, range);
}
let mut response = request.send().await?;
let status = response.status();
if !(status == StatusCode::OK || status == StatusCode::PARTIAL_CONTENT) {
return Err(CacheError::UnexpectedStatusCode(status));
}
let mut current_progress = start;
while let Some(chunk) = response.chunk().await? {
output_file.write_all(&chunk).await?;
tracing::trace!("wrote chunk of size {}", chunk.len());
current_progress += chunk.len() as u64;
if let Some(length) = length {
progress(FileLoadingProgress {
progress: current_progress,
cached_size: start,
size: length,
start_time: Some(std::time::Instant::now()),
});
}
}
tracing::trace!("Download of {} complete", file.display());
Ok(())
}
trait RequestBuilderExt {
fn with_authorization_header(self, token: Option<String>) -> Self;
}
impl RequestBuilderExt for reqwest::RequestBuilder {
fn with_authorization_header(self, token: Option<String>) -> Self {
if let Some(token) = token {
self.header(reqwest::header::AUTHORIZATION, format!("Bearer {token}"))
} else {
self
}
}
}
#[cfg(test)]
#[tokio::test]
async fn downloads_work() {
let url = "https://httpbin.org/range/102400?duration=2";
let file = PathBuf::from("download.bin");
let progress = |p| {
println!("Progress: {p:?}");
};
let client = reqwest::Client::new();
let response = client.head(url).send().await.unwrap();
download_into(url, &file, response, client, None, progress)
.await
.unwrap();
assert!(file.exists());
tokio::fs::remove_file(file).await.unwrap();
}
fn huggingface_token() -> Option<String> {
cfg!(not(target_arch = "wasm32"))
.then(|| {
let cache = hf_hub::Cache::default();
cache.token().or_else(|| std::env::var("HF_TOKEN").ok())
})
.flatten()
}