//! Device abstraction for CPU and GPU
use crate::Error;
/// Represents a compute device (CPU or GPU).
///
/// `Device::default()` yields [`Device::Cpu`], as selected by the `#[default]`
/// marker on that variant. `Clone` is derived, so cloning a GPU device clones
/// the wrapped `fusor_core::Device` value.
#[derive(Clone, Debug, Default)]
pub enum Device {
/// CPU device - uses fusor-cpu for SIMD-accelerated operations.
///
/// This is the default variant and carries no payload.
#[default]
Cpu,
/// GPU device - uses fusor-core (wgpu) for GPU-accelerated operations.
///
/// Wraps the underlying `fusor_core::Device`.
Gpu(fusor_core::Device),
}
impl Device {
    /// Construct the CPU device. Never fails and performs no setup.
    pub fn cpu() -> Self {
        Self::Cpu
    }

    /// Asynchronously construct a GPU device.
    ///
    /// Simply forwards to [`Device::gpu`]; it exists so the constructor name
    /// lines up with the fusor-core API.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Device::gpu`].
    pub async fn new() -> Result<Self, Error> {
        Self::gpu().await
    }

    /// Asynchronously construct a GPU device backed by `fusor_core::Device`.
    ///
    /// # Errors
    ///
    /// Returns an error (converted into [`Error`]) when
    /// `fusor_core::Device::new` fails.
    pub async fn gpu() -> Result<Self, Error> {
        Ok(Self::Gpu(fusor_core::Device::new().await?))
    }

    /// Construct a GPU device synchronously by driving [`Device::gpu`] to
    /// completion on the current thread via `pollster::block_on`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Device::gpu`].
    pub fn gpu_blocking() -> Result<Self, Error> {
        let pending = Self::gpu();
        pollster::block_on(pending)
    }

    /// Pick the best available device: a GPU when one can be created,
    /// otherwise the CPU.
    ///
    /// The GPU creation error, if any, is discarded.
    pub async fn auto() -> Self {
        Self::gpu().await.unwrap_or(Self::Cpu)
    }

    /// Reports whether this device is the CPU variant.
    #[inline]
    pub fn is_cpu(&self) -> bool {
        match self {
            Self::Cpu => true,
            Self::Gpu(_) => false,
        }
    }

    /// Reports whether this device is the GPU variant.
    #[inline]
    pub fn is_gpu(&self) -> bool {
        self.as_gpu().is_some()
    }

    /// Borrows the wrapped `fusor_core::Device`, or yields `None` for the CPU
    /// variant.
    #[inline]
    pub fn as_gpu(&self) -> Option<&fusor_core::Device> {
        if let Self::Gpu(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Resolve several compute-graph nodes together in one pass.
    ///
    /// For a GPU device the call is delegated to
    /// `fusor_core::Device::resolve_batch`, which builds a single shared
    /// execution graph so that intermediate buffers can be released as soon as
    /// all of their consumers in the batch have been computed; this keeps peak
    /// memory well below resolving the nodes one at a time. For the CPU
    /// device nothing happens.
    pub fn resolve_batch(&self, keys: &[fusor_core::NodeIndex]) {
        match self {
            Self::Gpu(gpu) => gpu.resolve_batch(keys),
            // Nothing to resolve on the CPU path.
            Self::Cpu => {}
        }
    }
}
impl PartialEq for Device {
    /// Compares two devices by variant only.
    ///
    /// `Cpu == Cpu` and `Gpu(_) == Gpu(_)` are both `true`; a CPU device never
    /// equals a GPU device. The payload of the GPU variant is not inspected,
    /// so any two GPU devices compare equal.
    ///
    /// NOTE(review): an earlier comment here said "GPU devices from the same
    /// Arc are equal", but no identity check is performed. Confirm whether
    /// distinct GPU handles are really meant to compare equal.
    fn eq(&self, other: &Self) -> bool {
        matches!(
            (self, other),
            (Device::Cpu, Device::Cpu) | (Device::Gpu(_), Device::Gpu(_))
        )
    }
}