use pulp::Simd;
use crate::pairwise::SimdBinaryOp;
use crate::{ConcreteTensor, SimdElement, TensorBacking, materialize_expr};
use fusor_types::Layout;
/// Declares one public, zero-sized marker type per identifier.
///
/// The markers carry no data; they exist only so that a comparison operation
/// can be selected at the type level (they are the `Self` type of the
/// `SimdBinaryOp` implementations generated further down in this file).
///
/// Besides `Copy`/`Clone`, every marker derives `Debug`, `Default`,
/// `PartialEq`, `Eq` and `Hash`. These are free for unit structs and let the
/// markers be printed in diagnostics, compared, default-constructed and used
/// as map/set keys.
macro_rules! define_cmp_marker {
    ($($marker:ident),* $(,)?) => {
        $(
            /// Zero-sized marker type naming one element-wise comparison operation.
            #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash)]
            pub struct $marker;
        )*
    };
}

// One marker per comparison: ==, !=, <, <=, >, >=.
define_cmp_marker!(EqOp, NeOp, LtOp, LteOp, GtOp, GteOp);
/// Element types that can encode a boolean comparison result as a number.
///
/// The comparison operators in this file return their result in the same
/// element type as their inputs: `one()` when the comparison holds and
/// `zero()` when it does not.
trait NumericBool: SimdElement {
/// The value used to represent `false`.
fn zero() -> Self;
/// The value used to represent `true`.
fn one() -> Self;
}
impl NumericBool for f32 {
fn zero() -> Self {
0.0
}
fn one() -> Self {
1.0
}
}
impl NumericBool for f64 {
fn zero() -> Self {
0.0
}
fn one() -> Self {
1.0
}
}
impl NumericBool for i8 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
impl NumericBool for i16 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
impl NumericBool for i32 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
impl NumericBool for i64 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
impl NumericBool for u8 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
impl NumericBool for u16 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
impl NumericBool for u32 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
impl NumericBool for u64 {
fn zero() -> Self {
0
}
fn one() -> Self {
1
}
}
/// Implements `SimdBinaryOp<$elem>` for the marker type `$op`.
///
/// `$cmp_fn` must be an expression coercible to `fn($elem, $elem) -> bool`
/// (a non-capturing closure works). The generated operation yields
/// `NumericBool::one()` where the predicate holds and `NumericBool::zero()`
/// where it does not, in the same element type as the inputs.
///
/// The vector path does not use any vector compare instruction: it views both
/// input vectors as lane slices, evaluates the predicate lane by lane into a
/// stack buffer, and reloads that buffer as a vector.
macro_rules! impl_scalar_comparison_op {
    ($op:ty, $cmp_fn:expr, $elem:ty) => {
        impl SimdBinaryOp<$elem> for $op {
            /// Compares `a` and `b` lane by lane and returns the 0/1 results as a vector.
            ///
            /// Panics if the number of lanes exceeds `crate::MAX_SIMD_LANES`.
            #[inline(always)]
            fn apply_simd_vec<S: Simd>(
                _simd: S,
                a: <$elem as SimdElement>::Simd<S>,
                b: <$elem as SimdElement>::Simd<S>,
            ) -> <$elem as SimdElement>::Simd<S> {
                // Lanes per vector, derived from the vector and element sizes.
                let lanes = std::mem::size_of::<<$elem as SimdElement>::Simd<S>>()
                    / std::mem::size_of::<$elem>();

                // Reinterpret each single vector as a slice of its lanes.
                let lhs_lanes: &[$elem] = pulp::bytemuck::cast_slice(std::slice::from_ref(&a));
                let rhs_lanes: &[$elem] = pulp::bytemuck::cast_slice(std::slice::from_ref(&b));

                // Scratch space sized for the widest supported vector; only the
                // first `lanes` entries are written and read back.
                let mut scratch = [<$elem>::default(); crate::MAX_SIMD_LANES];
                for (slot, (&x, &y)) in scratch[..lanes]
                    .iter_mut()
                    .zip(lhs_lanes.iter().zip(rhs_lanes.iter()))
                {
                    *slot = <$op as SimdBinaryOp<$elem>>::apply_scalar(x, y);
                }

                // The scratch prefix is exactly one vector long, so the first
                // (and only) vector of the split is the result.
                let (vectors, _) = <$elem as SimdElement>::as_simd::<S>(&scratch[..lanes]);
                vectors[0]
            }

            /// Compares two scalars and returns `one()` if the predicate holds, else `zero()`.
            #[inline(always)]
            fn apply_scalar(a: $elem, b: $elem) -> $elem {
                let holds: fn($elem, $elem) -> bool = $cmp_fn;
                if !holds(a, b) {
                    return <$elem as NumericBool>::zero();
                }
                <$elem as NumericBool>::one()
            }
        }
    };
}
/// Implements all six comparison markers (`EqOp`, `NeOp`, `LtOp`, `LteOp`,
/// `GtOp`, `GteOp`) for every listed element type.
///
/// Each marker is paired with the matching built-in operator
/// (`==`, `!=`, `<`, `<=`, `>`, `>=`) and handed to
/// `impl_scalar_comparison_op!`, which generates the actual trait impl.
macro_rules! impl_all_comparisons {
    ($($scalar:ty),*) => {
        $(
            impl_scalar_comparison_op!(EqOp, |lhs: $scalar, rhs: $scalar| lhs == rhs, $scalar);
            impl_scalar_comparison_op!(NeOp, |lhs: $scalar, rhs: $scalar| lhs != rhs, $scalar);
            impl_scalar_comparison_op!(LtOp, |lhs: $scalar, rhs: $scalar| lhs < rhs, $scalar);
            impl_scalar_comparison_op!(LteOp, |lhs: $scalar, rhs: $scalar| lhs <= rhs, $scalar);
            impl_scalar_comparison_op!(GtOp, |lhs: $scalar, rhs: $scalar| lhs > rhs, $scalar);
            impl_scalar_comparison_op!(GteOp, |lhs: $scalar, rhs: $scalar| lhs >= rhs, $scalar);
        )*
    };
}

// Floats, signed integers, unsigned integers.
impl_all_comparisons!(
    f32, f64,
    i8, i16, i32, i64,
    u8, u16, u32, u64
);
macro_rules! define_comparison_tensor_op {
($name:ident, $simd_op:ty) => {
pub struct $name<
E: SimdElement,
const R: usize,
T1: TensorBacking<R, Elem = E>,
T2: TensorBacking<R, Elem = E>,
> {
lhs: T1,
rhs: T2,
_marker: std::marker::PhantomData<E>,
}
impl<E, const R: usize, T1, T2> $name<E, R, T1, T2>
where
E: SimdElement,
T1: TensorBacking<R, Elem = E>,
T2: TensorBacking<R, Elem = E>,
{
pub fn new(lhs: T1, rhs: T2) -> Self {
Self {
lhs,
rhs,
_marker: std::marker::PhantomData,
}
}
}
impl<E, const R: usize, T1, T2> crate::LazyBacking for $name<E, R, T1, T2>
where
E: SimdElement + Default,
$simd_op: SimdBinaryOp<E>,
T1: TensorBacking<R, Elem = E>,
T2: TensorBacking<R, Elem = E>,
{
type Elem = E;
#[inline(always)]
fn eval_scalar(&self, idx: usize) -> E {
<$simd_op>::apply_scalar(self.lhs.eval_scalar(idx), self.rhs.eval_scalar(idx))
}
#[inline(always)]
fn eval_simd<S: Simd>(&self, simd: S, base_idx: usize) -> E::Simd<S> {
<$simd_op>::apply_simd_vec(
simd,
self.lhs.eval_simd(simd, base_idx),
self.rhs.eval_simd(simd, base_idx),
)
}
}
impl<E, const R: usize, T1, T2> TensorBacking<R> for $name<E, R, T1, T2>
where
E: SimdElement + Default,
$simd_op: SimdBinaryOp<E>,
T1: TensorBacking<R, Elem = E>,
T2: TensorBacking<R, Elem = E>,
{
fn layout(&self) -> Layout {
Layout::contiguous(self.lhs.layout().shape())
}
fn to_concrete(&self) -> ConcreteTensor<E, R> {
let shape: [usize; R] = self
.lhs
.layout()
.shape()
.try_into()
.expect("Shape length mismatch");
materialize_expr(self, shape)
}
}
};
}
define_comparison_tensor_op!(Eq, EqOp);
define_comparison_tensor_op!(Ne, NeOp);
define_comparison_tensor_op!(Lt, LtOp);
define_comparison_tensor_op!(Lte, LteOp);
define_comparison_tensor_op!(Gt, GtOp);
define_comparison_tensor_op!(Gte, GteOp);