use std::{
fmt::Display,
path::PathBuf,
sync::{
Arc, OnceLock,
atomic::{AtomicUsize, Ordering},
},
};
use clap::ValueEnum;
use ek_base::config::get_ek_settings;
use tokio::sync::{
Mutex,
mpsc::{Receiver, Sender},
};
use super::backend::Device;
/// Number of `EKInstance` values built through `EKInstance::default()`.
///
/// Incremented (SeqCst) on every `default()` call. NOTE(review): its value is not read
/// anywhere in this file — confirm whether other modules consume it.
static INSTANCE_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Selects which backend implementation an expert runs on.
///
/// Derives `clap::ValueEnum` so it can be parsed directly from command-line arguments.
/// The derived `PartialOrd`/`Ord` follow declaration order, so reordering variants
/// changes comparison results.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
pub enum ExpertBackendType {
    /// libtorch via the `tch` bindings.
    Torch,
    /// ONNX backend (configured as `"ort"` in worker settings).
    Onnx,
    /// GGML backend.
    Ggml,
}
impl TryFrom<&str> for ExpertBackendType {
    type Error = &'static str;

    /// Parses a backend name, e.g. the `worker.backend` setting, into an
    /// [`ExpertBackendType`].
    ///
    /// Accepted names (case-sensitive): `"torch"`, `"ggml"`, and either `"onnx"` or the
    /// existing alias `"ort"` for [`ExpertBackendType::Onnx`]. Accepting `"onnx"` makes this
    /// the inverse of the `Display` impl, so `ExpertBackendType::try_from(b.to_string().as_str())`
    /// round-trips for every variant.
    ///
    /// # Errors
    /// Returns `"Unknown backend type"` for any other string.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "torch" => Ok(ExpertBackendType::Torch),
            // "ort" is kept for backward compatibility with existing configs.
            "onnx" | "ort" => Ok(ExpertBackendType::Onnx),
            "ggml" => Ok(ExpertBackendType::Ggml),
            _ => Err("Unknown backend type"),
        }
    }
}
impl Display for ExpertBackendType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExpertBackendType::Torch => write!(f, "torch"),
ExpertBackendType::Onnx => write!(f, "onnx"),
ExpertBackendType::Ggml => write!(f, "ggml"),
}
}
}
/// Dimensions, backend and device used to build an expert instance.
///
/// `Default` fills every field from the global EK settings (`get_ek_settings()`).
#[derive(Clone, Copy)]
pub struct EKInstance {
    /// Hidden dimension; `Default` takes it from `settings.inference.hidden_dim`.
    pub hidden: usize,
    /// Intermediate dimension; `Default` takes it from `settings.inference.intermediate_dim`.
    pub intermediate: usize,
    /// Backend implementation; `Default` parses it from `settings.worker.backend`.
    pub backend: ExpertBackendType,
    /// Target device; `Default` converts it from `settings.worker.device`.
    pub device: Device,
}
impl Default for EKInstance {
    /// Builds an instance from the global EK settings returned by `get_ek_settings()`.
    ///
    /// Reads `inference.hidden_dim`, `inference.intermediate_dim`, `worker.device` and
    /// `worker.backend`. As a side effect, increments `INSTANCE_COUNTER`.
    ///
    /// # Panics
    /// Panics if `worker.backend` is not a name accepted by
    /// `ExpertBackendType::try_from`. The panic message includes the offending value so a
    /// misconfigured worker can be diagnosed without a debugger.
    fn default() -> Self {
        let settings = get_ek_settings();
        // Only the increment matters; the previous count is intentionally discarded.
        let _ = INSTANCE_COUNTER.fetch_add(1, Ordering::SeqCst);
        let device = Device::from(settings.worker.device.as_str());
        let backend_name = settings.worker.backend.as_str();
        let backend = ExpertBackendType::try_from(backend_name).unwrap_or_else(|err| {
            panic!("invalid worker.backend setting {backend_name:?}: {err}")
        });
        Self {
            hidden: settings.inference.hidden_dim,
            intermediate: settings.inference.intermediate_dim,
            backend,
            device,
        }
    }
}
/// Returns the value of `CARGO_MANIFEST_DIR` as a `PathBuf`.
///
/// The directory is captured at compile time via `env!`, so it points at the manifest
/// directory of the crate being built, not at the process's working directory.
pub fn test_root() -> PathBuf {
    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
}
/// Sender half plus a shared, `tokio::sync::Mutex`-guarded receiver of a `()` channel.
type GracefulChannelPair = (Sender<()>, Arc<Mutex<Receiver<()>>>);

/// Returns handles to the single graceful-shutdown channel.
///
/// The channel is a bounded `tokio::sync::mpsc` channel of capacity 1. It is created on
/// the first call and stored in a function-local `OnceLock`. Every call returns a clone of
/// the same `Sender` and a clone of the `Arc` wrapping the same `Receiver`, so all callers
/// share one channel.
pub fn get_graceful_shutdown_ch() -> GracefulChannelPair {
    static SHUTDOWN_CHANNEL: OnceLock<GracefulChannelPair> = OnceLock::new();

    let (sender, shared_receiver) = SHUTDOWN_CHANNEL.get_or_init(|| {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        (tx, Arc::new(Mutex::new(rx)))
    });
    (sender.clone(), Arc::clone(shared_receiver))
}
#[cfg(test)]
mod test {
    use tch::{Cuda, Device, Kind, Tensor};

    /// Prints the CUDA device count and availability reported by `tch`; asserts nothing.
    #[test]
    fn test_env() {
        let device_count = Cuda::device_count();
        println!("CUDA Device Count: {device_count}");
        let available = Cuda::is_available();
        println!("CUDA available: {available}");
    }

    /// Allocates a 1x2 float zero tensor on CUDA device 0.
    ///
    /// Expected to fail on machines without a usable CUDA device (tch's non-`f_`
    /// constructors panic on error).
    #[test]
    fn test_force_cuda() {
        let options = (Kind::Float, Device::Cuda(0));
        // Bound to `_`, so the tensor is dropped right away; only the allocation matters.
        let _ = Tensor::zeros([1, 2], options);
        println!("Tensor on CUDA successfully created.");
    }
}