//! Conditional tensor operations: where_cond
//! Selects elements based on condition tensor != 0
use crate::expr::linear_to_indices;
use crate::{ConcreteTensor, ResolvedTensor, SimdElement};
/// Helper trait for types that can be compared to zero
pub trait IsNonZero: SimdElement {
fn is_nonzero(&self) -> bool;
}
macro_rules! impl_is_nonzero {
($($ty:ty => $zero:expr),*) => {
$(
impl IsNonZero for $ty {
fn is_nonzero(&self) -> bool {
*self != $zero
}
}
)*
};
}
impl_is_nonzero!(
f32 => 0.0, f64 => 0.0,
i8 => 0, i16 => 0, i32 => 0, i64 => 0,
u8 => 0, u16 => 0, u32 => 0, u64 => 0
);
/// Conditional selection: where condition != 0, select on_true, else on_false
#[inline(always)]
pub(crate) fn where_cond_ref<E, const R: usize>(
cond: &ConcreteTensor<E, R>,
on_true: &ConcreteTensor<E, R>,
on_false: &ConcreteTensor<E, R>,
) -> ConcreteTensor<E, R>
where
E: SimdElement + IsNonZero,
{
let shape: [usize; R] = cond
.layout()
.shape()
.try_into()
.expect("Shape length mismatch");
debug_assert_eq!(
cond.layout().shape(),
on_true.layout().shape(),
"where_cond: cond and on_true shape mismatch"
);
debug_assert_eq!(
cond.layout().shape(),
on_false.layout().shape(),
"where_cond: cond and on_false shape mismatch"
);
let all_contiguous = cond.layout().is_contiguous()
&& on_true.layout().is_contiguous()
&& on_false.layout().is_contiguous();
if all_contiguous {
let cond_data = cond.data();
let true_data = on_true.data();
let false_data = on_false.data();
ConcreteTensor::from_fn(shape, |i| {
if cond_data[i].is_nonzero() {
true_data[i]
} else {
false_data[i]
}
})
} else {
ConcreteTensor::from_fn(shape, |out_idx| {
let indices = linear_to_indices::<R>(out_idx, &shape);
let cond_idx = cond.layout().linear_index(&indices);
let true_idx = on_true.layout().linear_index(&indices);
let false_idx = on_false.layout().linear_index(&indices);
let cond_val = cond.data()[cond_idx];
if cond_val.is_nonzero() {
on_true.data()[true_idx]
} else {
on_false.data()[false_idx]
}
})
}
}