#![warn(missing_docs)]
use std::{sync::OnceLock, time::Duration};
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures_util::{Stream, StreamExt};
use image::ImageBuffer;
use kalosm_common::{Cache, CacheError};
pub use kalosm_model_types::ModelLoadingProgress;
use kalosm_model_types::{FileSource, ModelBuilder};
use model::{WuerstcheModelSettings, WuerstchenInner};
mod model;
/// Shared, lazily initialised empty (0x0) image.
///
/// `Image::as_ref` hands out a reference to this buffer when the image holds
/// an error instead of a generated picture.
static ZERO_IMAGE: OnceLock<ImageBuffer<image::Rgb<u8>, Vec<u8>>> = OnceLock::new();
/// The payload of a successfully produced sample: the image buffer together
/// with the height and width values exposed through [`Image`].
#[derive(Debug, Clone)]
struct DiffusionResult {
/// The RGB image buffer returned by `Image::generated_image`.
image: ImageBuffer<image::Rgb<u8>, Vec<u8>>,
/// Value returned by `Image::height` (presumably pixels; confirm against the producer in `model`).
height: usize,
/// Value returned by `Image::width` (presumably pixels; confirm against the producer in `model`).
width: usize,
}
/// One item produced by the Wuerstchen model: either a generated image or the
/// error that occurred, plus the bookkeeping values stored alongside it.
#[derive(Debug)]
pub struct Image {
// Sample number reported by `sample_num()`.
sample_num: i64,
// Duration reported by `elapsed_time()`.
elapsed_time: Duration,
// Duration reported by `remaining_time()` (presumably an estimate; confirm in `model`).
remaining_time: Duration,
// Value reported by `progress()` (range/units are set by the producer in `model`; TODO confirm).
progress: f32,
// `Ok` holds the generated image, `Err` holds the candle error for this sample.
result: candle_core::Result<DiffusionResult>,
}
impl Image {
    /// Returns the sample number stored with this image.
    pub fn sample_num(&self) -> i64 {
        self.sample_num
    }

    /// Returns the elapsed time stored with this image.
    pub fn elapsed_time(&self) -> Duration {
        self.elapsed_time
    }

    /// Returns the remaining time stored with this image.
    pub fn remaining_time(&self) -> Duration {
        self.remaining_time
    }

    /// Returns the progress value stored with this image.
    pub fn progress(&self) -> f32 {
        self.progress
    }

    /// Returns the height of the generated image, or `None` if generation failed.
    pub fn height(&self) -> Option<usize> {
        match &self.result {
            Ok(diffusion) => Some(diffusion.height),
            Err(_) => None,
        }
    }

    /// Returns the width of the generated image, or `None` if generation failed.
    pub fn width(&self) -> Option<usize> {
        match &self.result {
            Ok(diffusion) => Some(diffusion.width),
            Err(_) => None,
        }
    }

    /// Returns a copy of the generated image buffer, or `None` if generation failed.
    pub fn generated_image(&self) -> Option<ImageBuffer<image::Rgb<u8>, Vec<u8>>> {
        let diffusion = self.result.as_ref().ok()?;
        Some(diffusion.image.clone())
    }

    /// Returns the error that occurred while generating, or `None` on success.
    pub fn error(&self) -> Option<&candle_core::Error> {
        match &self.result {
            Ok(_) => None,
            Err(failure) => Some(failure),
        }
    }
}
impl AsRef<ImageBuffer<image::Rgb<u8>, Vec<u8>>> for Image {
    /// Borrows the generated image buffer.
    ///
    /// When this `Image` holds an error, a reference to the shared empty
    /// (0x0) `ZERO_IMAGE` buffer is returned instead.
    fn as_ref(&self) -> &ImageBuffer<image::Rgb<u8>, Vec<u8>> {
        if let Ok(diffusion) = &self.result {
            return &diffusion.image;
        }
        // No image available: fall back to the lazily created empty buffer.
        ZERO_IMAGE.get_or_init(|| ImageBuffer::new(0, 0))
    }
}
/// A builder for the [`Wuerstchen`] model.
///
/// Every `Option<String>` field is an optional local file path; when it is
/// `None` the corresponding default Hugging Face file (see `ModelFile`) is used.
#[derive(Default)]
pub struct WuerstchenBuilder {
// Forwarded unchanged into the model settings.
use_flash_attn: bool,
// Local path override for the decoder weights.
decoder_weights: Option<String>,
// Local path override for the clip (text encoder) weights.
clip_weights: Option<String>,
// Local path override for the prior clip (text encoder) weights.
prior_clip_weights: Option<String>,
// Local path override for the prior weights.
prior_weights: Option<String>,
// Local path override for the VQGAN weights.
vqgan_weights: Option<String>,
// Local path override for the tokenizer file.
tokenizer: Option<String>,
// Local path override for the prior tokenizer file.
prior_tokenizer: Option<String>,
// Cache used by `build_with_loading_handler` to resolve every file.
cache: Cache,
}
impl WuerstchenBuilder {
pub fn with_flash_attn(mut self, use_flash_attn: bool) -> Self {
self.use_flash_attn = use_flash_attn;
self
}
pub fn with_decoder_weights(mut self, decoder_weights: impl Into<String>) -> Self {
self.decoder_weights = Some(decoder_weights.into());
self
}
pub fn with_clip_weights(mut self, clip_weights: impl Into<String>) -> Self {
self.clip_weights = Some(clip_weights.into());
self
}
pub fn with_prior_clip_weights(mut self, prior_clip_weights: impl Into<String>) -> Self {
self.prior_clip_weights = Some(prior_clip_weights.into());
self
}
pub fn with_prior_weights(mut self, prior_weights: impl Into<String>) -> Self {
self.prior_weights = Some(prior_weights.into());
self
}
pub fn with_vqgan_weights(mut self, vqgan_weights: impl Into<String>) -> Self {
self.vqgan_weights = Some(vqgan_weights.into());
self
}
pub fn with_tokenizer(mut self, tokenizer: impl Into<String>) -> Self {
self.tokenizer = Some(tokenizer.into());
self
}
pub fn with_prior_tokenizer(mut self, prior_tokenizer: impl Into<String>) -> Self {
self.prior_tokenizer = Some(prior_tokenizer.into());
self
}
pub fn with_cache(mut self, cache: kalosm_common::Cache) -> Self {
self.cache = cache;
self
}
pub async fn build(self) -> Result<Wuerstchen, CacheError> {
self.build_with_loading_handler(ModelLoadingProgress::multi_bar_loading_indicator())
.await
}
pub async fn build_with_loading_handler(
self,
mut progress_handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
) -> Result<Wuerstchen, CacheError> {
let WuerstchenBuilder {
use_flash_attn,
decoder_weights,
clip_weights,
prior_clip_weights,
prior_weights,
vqgan_weights,
tokenizer,
prior_tokenizer,
cache,
} = self;
let prior_tokenizer_source = ModelFile::PriorTokenizer.get(prior_tokenizer);
let prior_tokenizer_source_display = format!("Prior Tokenizer ({prior_tokenizer_source})");
let mut create_progress =
ModelLoadingProgress::downloading_progress(prior_tokenizer_source_display);
let prior_tokenizer = cache
.get(&prior_tokenizer_source, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let tokenizer_source = ModelFile::Tokenizer.get(tokenizer);
let tokenizer_source_display = format!("Tokenizer ({tokenizer_source})");
let mut create_progress =
ModelLoadingProgress::downloading_progress(tokenizer_source_display);
let tokenizer = cache
.get(&tokenizer_source, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let clip_weights_source = ModelFile::Clip.get(clip_weights);
let clip_weights_source_display = format!("Clip Weights ({clip_weights_source})");
let mut create_progress =
ModelLoadingProgress::downloading_progress(clip_weights_source_display);
let clip_weights = cache
.get(&clip_weights_source, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let prior_clip_weights_source = ModelFile::PriorClip.get(prior_clip_weights);
let prior_clip_weights_source_display =
format!("Prior Clip Weights ({prior_clip_weights_source})");
let mut create_progress =
ModelLoadingProgress::downloading_progress(prior_clip_weights_source_display);
let prior_clip_weights = cache
.get(&prior_clip_weights_source, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let decoder_weights_source = ModelFile::Decoder.get(decoder_weights);
let decoder_weights_source_display = format!("Decoder Weights ({decoder_weights_source})");
let mut create_progress =
ModelLoadingProgress::downloading_progress(decoder_weights_source_display);
let decoder_weights = cache
.get(&decoder_weights_source, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let prior_weights_source = ModelFile::Prior.get(prior_weights);
let prior_weights_source_display = format!("Prior Weights ({prior_weights_source})");
let mut create_progress =
ModelLoadingProgress::downloading_progress(prior_weights_source_display);
let prior_weights = cache
.get(&prior_weights_source, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let vqgan_weights_source = ModelFile::VqGan.get(vqgan_weights);
let vqgan_weights_source_display = format!("VQGAN Weights ({vqgan_weights_source})");
let mut create_progress =
ModelLoadingProgress::downloading_progress(vqgan_weights_source_display);
let vqgan_weights = cache
.get(&vqgan_weights_source, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let settings = WuerstcheModelSettings {
use_flash_attn,
decoder_weights,
clip_weights,
prior_clip_weights,
prior_weights,
vqgan_weights,
tokenizer,
prior_tokenizer,
};
let model = WuerstchenInner::new(settings).unwrap();
let (rx, tx) = std::sync::mpsc::channel();
let thread = std::thread::spawn(move || {
while let Ok(message) = tx.recv() {
match message {
WuerstchenMessage::Kill => return,
WuerstchenMessage::Generate(input, result) => {
model.run(input, result);
}
}
}
});
Ok(Wuerstchen {
thread: Some(thread),
sender: rx,
})
}
}
impl ModelBuilder for WuerstchenBuilder {
type Model = Wuerstchen;
type Error = CacheError;
async fn start_with_loading_handler(
self,
handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
) -> Result<Self::Model, Self::Error> {
self.build_with_loading_handler(handler).await
}
fn requires_download(&self) -> bool {
let cache = Cache::default();
let downloaded_decoder_weights = self.decoder_weights.is_none()
|| cache.exists(&<&ModelFile as Into<FileSource>>::into(&ModelFile::Decoder));
let downloaded_clip_weights = self.clip_weights.is_none()
|| cache.exists(&<&ModelFile as Into<FileSource>>::into(&ModelFile::Clip));
let downloaded_prior_clip_weights = self.prior_clip_weights.is_none()
|| cache.exists(&<&ModelFile as Into<FileSource>>::into(
&ModelFile::PriorClip,
));
let downloaded_prior_weights = self.prior_weights.is_none()
|| cache.exists(&<&ModelFile as Into<FileSource>>::into(&ModelFile::Prior));
let downloaded_vqgan_weights = self.vqgan_weights.is_none()
|| cache.exists(&<&ModelFile as Into<FileSource>>::into(&ModelFile::VqGan));
let downloaded_tokenizer = self.tokenizer.is_none()
|| cache.exists(&<&ModelFile as Into<FileSource>>::into(
&ModelFile::Tokenizer,
));
let downloaded_prior_tokenizer = self.prior_tokenizer.is_none()
|| cache.exists(&<&ModelFile as Into<FileSource>>::into(
&ModelFile::PriorTokenizer,
));
!(downloaded_decoder_weights
&& downloaded_clip_weights
&& downloaded_prior_clip_weights
&& downloaded_prior_weights
&& downloaded_vqgan_weights
&& downloaded_tokenizer
&& downloaded_prior_tokenizer)
}
}
/// A handle to a Wuerstchen image generation model running on its own thread.
///
/// Requests are sent to the worker thread over a channel; dropping the handle
/// sends a kill message and joins the thread.
pub struct Wuerstchen {
// Join handle of the worker thread; an `Option` so `Drop` can take it out.
thread: Option<std::thread::JoinHandle<()>>,
// Channel used to send generation requests and the kill message to the worker.
sender: std::sync::mpsc::Sender<WuerstchenMessage>,
}
impl Wuerstchen {
    /// Creates the model with the default builder settings.
    ///
    /// Equivalent to `Wuerstchen::builder().build().await`.
    pub async fn new() -> Result<Self, CacheError> {
        let default_builder = Self::builder();
        default_builder.build().await
    }

    /// Returns a [`WuerstchenBuilder`] with default settings.
    pub fn builder() -> WuerstchenBuilder {
        Default::default()
    }

    /// Queues a generation request and returns a stream of the resulting images.
    pub fn run(&self, settings: WuerstchenInferenceSettings) -> ChannelImageStream<Image> {
        let (image_sender, image_receiver) = futures_channel::mpsc::unbounded();
        self.run_into(settings, image_sender);
        image_receiver.into()
    }

    /// Queues a generation request whose images are delivered to `sender`.
    ///
    /// If the request cannot be handed to the worker thread it is silently
    /// discarded.
    pub fn run_into(&self, settings: WuerstchenInferenceSettings, sender: UnboundedSender<Image>) {
        let request = WuerstchenMessage::Generate(settings, sender);
        // A failed send is deliberately ignored, as in the rest of this API.
        let _ = self.sender.send(request);
    }
}
impl Drop for Wuerstchen {
    /// Asks the worker thread to stop and waits for it to finish.
    ///
    /// Shutdown is best-effort and never panics: a destructor that panics
    /// while another panic is unwinding aborts the whole process.
    fn drop(&mut self) {
        // A send error only means the worker already dropped its receiver,
        // i.e. the thread has exited (or panicked) and needs no kill message.
        let _ = self.sender.send(WuerstchenMessage::Kill);
        if let Some(worker) = self.thread.take() {
            // `join` returns `Err` when the worker panicked; that panic has
            // already been reported on the worker thread, so it is not
            // re-raised from the destructor.
            let _ = worker.join();
        }
    }
}
/// Messages sent from a `Wuerstchen` handle to its worker thread.
enum WuerstchenMessage {
/// Makes the worker loop return, ending the thread.
Kill,
/// Runs the model with the given settings; the sender is handed to `model.run` together with them.
Generate(WuerstchenInferenceSettings, UnboundedSender<Image>),
}
/// Settings for a single Wuerstchen generation request.
///
/// Create one with [`WuerstchenInferenceSettings::new`] and adjust it with the
/// `with_*` methods.
pub struct WuerstchenInferenceSettings {
    // The prompt passed to `new`.
    prompt: String,
    // The negative prompt; empty by default.
    uncond_prompt: String,
    // Requested height; 1024 by default.
    height: usize,
    // Requested width; 1024 by default.
    width: usize,
    // Number of prior steps; 60 by default.
    prior_steps: usize,
    // Number of denoiser steps; 12 by default.
    denoiser_steps: usize,
    // Number of samples; 1 by default.
    num_samples: i64,
    // Prior guidance scale; 4.0 by default.
    prior_guidance_scale: f64,
}

impl WuerstchenInferenceSettings {
    /// Creates settings for `prompt` with the defaults: empty negative prompt,
    /// 1024x1024, 60 prior steps, 12 denoiser steps, 1 sample and a prior
    /// guidance scale of 4.0.
    pub fn new(prompt: impl Into<String>) -> Self {
        let prompt = prompt.into();
        let uncond_prompt = String::new();
        Self {
            prompt,
            uncond_prompt,
            height: 1024,
            width: 1024,
            prior_steps: 60,
            denoiser_steps: 12,
            num_samples: 1,
            prior_guidance_scale: 4.0,
        }
    }

    /// Sets the negative prompt.
    pub fn with_negative_prompt(self, uncond_prompt: impl Into<String>) -> Self {
        Self {
            uncond_prompt: uncond_prompt.into(),
            ..self
        }
    }

    /// Sets the height.
    pub fn with_height(self, height: usize) -> Self {
        Self { height, ..self }
    }

    /// Sets the width.
    pub fn with_width(self, width: usize) -> Self {
        Self { width, ..self }
    }

    /// Sets the number of prior steps.
    pub fn with_prior_steps(self, prior_steps: usize) -> Self {
        Self {
            prior_steps,
            ..self
        }
    }

    /// Sets the number of denoiser steps.
    pub fn with_denoiser_steps(self, denoiser_steps: usize) -> Self {
        Self {
            denoiser_steps,
            ..self
        }
    }

    /// Sets the number of samples.
    pub fn with_sample_count(self, sample_count: i64) -> Self {
        Self {
            num_samples: sample_count,
            ..self
        }
    }

    /// Sets the prior guidance scale.
    pub fn with_prior_guidance_scale(self, prior_guidance_scale: f64) -> Self {
        Self {
            prior_guidance_scale,
            ..self
        }
    }
}
/// The individual files the Wuerstchen model is built from.
///
/// Each variant maps to a default Hugging Face location through
/// `From<&ModelFile> for FileSource`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
/// `tokenizer/tokenizer.json` in `warp-ai/wuerstchen`.
Tokenizer,
/// `tokenizer/tokenizer.json` in `warp-ai/wuerstchen-prior`.
PriorTokenizer,
/// `text_encoder/model.safetensors` in `warp-ai/wuerstchen`.
Clip,
/// `text_encoder/model.safetensors` in `warp-ai/wuerstchen-prior`.
PriorClip,
/// `decoder/diffusion_pytorch_model.safetensors` in `warp-ai/wuerstchen`.
Decoder,
/// `vqgan/diffusion_pytorch_model.safetensors` in `warp-ai/wuerstchen`.
VqGan,
/// `prior/diffusion_pytorch_model.safetensors` in `warp-ai/wuerstchen-prior`.
Prior,
}
impl ModelFile {
    /// Resolves the source for this file.
    ///
    /// A provided `filename` is treated as a local path; otherwise the
    /// default Hugging Face source for this variant is used.
    fn get(&self, filename: Option<String>) -> FileSource {
        if let Some(local_path) = filename {
            return FileSource::local(std::path::PathBuf::from(local_path));
        }
        FileSource::from(self)
    }
}
impl From<&ModelFile> for FileSource {
    /// Maps a model file to its default Hugging Face source on the `main`
    /// revision of either the main or the prior Wuerstchen repository.
    fn from(val: &ModelFile) -> Self {
        const MAIN_REPO: &str = "warp-ai/wuerstchen";
        const PRIOR_REPO: &str = "warp-ai/wuerstchen-prior";

        // Which repository hosts the file.
        let repo = match val {
            ModelFile::Tokenizer | ModelFile::Clip | ModelFile::Decoder | ModelFile::VqGan => {
                MAIN_REPO
            }
            ModelFile::PriorTokenizer | ModelFile::PriorClip | ModelFile::Prior => PRIOR_REPO,
        };
        // Where the file lives inside that repository.
        let path = match val {
            ModelFile::Tokenizer | ModelFile::PriorTokenizer => "tokenizer/tokenizer.json",
            ModelFile::Clip | ModelFile::PriorClip => "text_encoder/model.safetensors",
            ModelFile::Decoder => "decoder/diffusion_pytorch_model.safetensors",
            ModelFile::VqGan => "vqgan/diffusion_pytorch_model.safetensors",
            ModelFile::Prior => "prior/diffusion_pytorch_model.safetensors",
        };

        FileSource::huggingface(String::from(repo), String::from("main"), String::from(path))
    }
}
/// A stream of image-like items received over an unbounded channel.
///
/// Returned by `Wuerstchen::run`; each item can be viewed as an `ImageBuffer`
/// through `AsRef`.
pub struct ChannelImageStream<S: AsRef<ImageBuffer<image::Rgb<u8>, Vec<u8>>>> {
// Receiving half of the channel the items arrive on.
receiver: UnboundedReceiver<S>,
}
impl<S> std::fmt::Debug for ChannelImageStream<S>
where
    S: AsRef<ImageBuffer<image::Rgb<u8>, Vec<u8>>>,
{
    /// Formats the stream as its bare type name; the receiver is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ChannelImageStream")
    }
}
impl<S: AsRef<ImageBuffer<image::Rgb<u8>, Vec<u8>>>> From<UnboundedReceiver<S>>
for ChannelImageStream<S>
{
fn from(receiver: UnboundedReceiver<S>) -> Self {
Self { receiver }
}
}
impl<S> Stream for ChannelImageStream<S>
where
    S: AsRef<ImageBuffer<image::Rgb<u8>, Vec<u8>>>,
{
    type Item = S;

    /// Polls the wrapped receiver for the next item.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> core::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.receiver.poll_next_unpin(cx)
    }
}