use std::ops::{Add as StdAdd, Div as StdDiv, Mul as StdMul, Sub as StdSub};
use pulp::Simd;
use crate::pairwise::{AddOp, DivOp, MulOp, SimdBinaryOp, SubOp};
use crate::{ConcreteTensor, SimdElement, TensorBacking, materialize_expr};
use fusor_types::Layout;
/// Generates a tensor-with-scalar expression type named `$name`.
///
/// The generated struct stores a tensor backing together with one scalar and
/// implements `crate::LazyBacking` and `TensorBacking<R>` by combining each
/// tensor element with that scalar through `$simd_op`.
///
/// * `$name` - identifier of the generated struct.
/// * `$std_trait` - operator trait the element type must implement (with
///   `Output = E`) for the generated trait impls to apply.
/// * `$simd_op` - type that must implement `SimdBinaryOp<E>`; its
///   `apply_scalar` and `apply_simd_vec` functions perform the operation.
macro_rules! define_scalar_tensor_op {
    ($name:ident, $std_trait:ident, $simd_op:ty) => {
        /// Expression pairing a tensor backing with a single scalar operand.
        pub struct $name<E, const R: usize, T>
        where
            E: SimdElement,
            T: TensorBacking<R, Elem = E>,
        {
            /// Backing whose elements are passed as the first operand.
            tensor: T,
            /// Value passed as the second operand for every element.
            scalar: E,
        }

        impl<E: SimdElement, const R: usize, T: TensorBacking<R, Elem = E>> $name<E, R, T> {
            /// Wraps `tensor` and `scalar` into a new expression; nothing is
            /// evaluated here.
            pub fn new(tensor: T, scalar: E) -> Self {
                Self { tensor, scalar }
            }
        }

        impl<E, const R: usize, T> crate::LazyBacking for $name<E, R, T>
        where
            T: TensorBacking<R, Elem = E>,
            $simd_op: SimdBinaryOp<E>,
            E: SimdElement + $std_trait<Output = E> + Default,
        {
            type Elem = E;

            /// Evaluates the wrapped tensor at `idx` and applies the operation
            /// with the tensor element first and the stored scalar second.
            #[inline(always)]
            fn eval_scalar(&self, idx: usize) -> E {
                let lhs = self.tensor.eval_scalar(idx);
                <$simd_op>::apply_scalar(lhs, self.scalar)
            }

            /// Evaluates the wrapped tensor's SIMD vector at `base_idx` and
            /// applies the operation against the scalar splatted across a vector.
            #[inline(always)]
            fn eval_simd<S: Simd>(&self, simd: S, base_idx: usize) -> E::Simd<S> {
                let lhs = self.tensor.eval_simd(simd, base_idx);
                let rhs = E::splat(simd, self.scalar);
                <$simd_op>::apply_simd_vec(simd, lhs, rhs)
            }
        }

        impl<E, const R: usize, T> TensorBacking<R> for $name<E, R, T>
        where
            T: TensorBacking<R, Elem = E>,
            $simd_op: SimdBinaryOp<E>,
            E: SimdElement + $std_trait<Output = E> + Default,
        {
            /// Builds a layout with `Layout::contiguous` from the wrapped
            /// tensor's shape.
            fn layout(&self) -> Layout {
                let inner = self.tensor.layout();
                Layout::contiguous(inner.shape())
            }

            /// Converts the wrapped tensor's shape into a `[usize; R]` and hands
            /// this expression to `materialize_expr`.
            ///
            /// # Panics
            ///
            /// Panics with "Shape length mismatch" if the shape cannot be
            /// converted into an array of length `R`.
            fn to_concrete(&self) -> ConcreteTensor<E, R> {
                let inner = self.tensor.layout();
                let shape: [usize; R] = inner
                    .shape()
                    .try_into()
                    .expect("Shape length mismatch");
                materialize_expr(self, shape)
            }
        }
    };
}
// Instantiate one tensor-with-scalar expression type per arithmetic operator,
// pairing each std operator trait with the matching op type imported from
// `crate::pairwise`.
define_scalar_tensor_op!(AddScalar, StdAdd, AddOp);
define_scalar_tensor_op!(SubScalar, StdSub, SubOp);
define_scalar_tensor_op!(MulScalar, StdMul, MulOp);
define_scalar_tensor_op!(DivScalar, StdDiv, DivOp);
/// Expression that yields the same scalar at every index of a fixed shape.
pub struct Broadcast<E, const R: usize>
where
    E: SimdElement,
{
    /// Value returned for every element.
    scalar: E,
    /// Shape used when building the layout and when materializing.
    shape: [usize; R],
}
impl<E: SimdElement, const R: usize> Broadcast<E, R> {
pub fn new(scalar: E, shape: [usize; R]) -> Self {
Self { scalar, shape }
}
}
impl<E, const R: usize> crate::LazyBacking for Broadcast<E, R>
where
    E: SimdElement + Default,
{
    type Elem = E;

    /// Returns the stored scalar; the index is ignored.
    #[inline(always)]
    fn eval_scalar(&self, _idx: usize) -> E {
        let value = self.scalar;
        value
    }

    /// Returns the stored scalar splatted into a SIMD vector; the base index
    /// is ignored.
    #[inline(always)]
    fn eval_simd<S: Simd>(&self, simd: S, _base_idx: usize) -> E::Simd<S> {
        let value = self.scalar;
        E::splat(simd, value)
    }
}
impl<E, const R: usize> TensorBacking<R> for Broadcast<E, R>
where
    E: SimdElement + Default,
{
    /// Builds a layout with `Layout::contiguous` from the stored shape.
    fn layout(&self) -> Layout {
        let dims = &self.shape;
        Layout::contiguous(dims)
    }

    /// Hands this expression and its stored shape to `materialize_expr`.
    fn to_concrete(&self) -> ConcreteTensor<E, R> {
        let shape = self.shape;
        materialize_expr(self, shape)
    }
}