use std::{
error::Error,
fmt::{Debug, Display},
ops::Sub,
ops::{Range, RangeInclusive},
pin::Pin,
sync::Arc,
};
use fusor::{DataType, Device, FromArray, SimdElement, Tensor};
use half::f16;
use rand::{
RngCore, SeedableRng,
distr::{Distribution, StandardUniform, Uniform},
rngs::StdRng,
};
use thiserror::Error;
/// Reports whether the `FUSOR_CONFORMANCE_REQUIRE_GPU` environment variable
/// asks for a GPU to be mandatory.
///
/// The variable counts as "set" when, after trimming whitespace, it is
/// non-empty, is not `0`, and is not `false` (compared ASCII
/// case-insensitively). An unset or unreadable variable yields `false`.
fn require_gpu() -> bool {
    match std::env::var("FUSOR_CONFORMANCE_REQUIRE_GPU") {
        Ok(raw) => {
            let flag = raw.trim();
            let disabled = flag.is_empty() || flag == "0" || flag.eq_ignore_ascii_case("false");
            !disabled
        }
        Err(_) => false,
    }
}
/// Returns the devices conformance tests should run on.
///
/// The CPU device is always first. A GPU device is appended when
/// `Device::gpu()` succeeds.
///
/// # Panics
///
/// Panics when no GPU could be acquired and [`require_gpu`] reports that a GPU
/// is mandatory.
pub async fn available_devices() -> Vec<Device> {
    let gpu = match Device::gpu().await {
        Ok(gpu) => Some(gpu),
        Err(err) => {
            let gpu_is_mandatory = require_gpu();
            assert!(
                !gpu_is_mandatory,
                "GPU conformance is required but no GPU device was available: {err}"
            );
            None
        }
    };
    std::iter::once(Device::Cpu).chain(gpu).collect()
}
/// Returns the subset of [`available_devices`] usable for `f16` tests.
///
/// The CPU device is always kept; a GPU device is kept only when its
/// `f16_supported()` query returns `true`. Relative order is preserved.
pub async fn f16_capable_devices() -> Vec<Device> {
    let mut devices = available_devices().await;
    devices.retain(|device| match device {
        Device::Cpu => true,
        Device::Gpu(gpu) => gpu.f16_supported(),
    });
    devices
}
/// Iterates over every multi-dimensional index of `shape` in row-major order
/// (the last axis varies fastest).
///
/// A shape containing a zero extent yields nothing; a rank-0 shape yields a
/// single empty index.
fn index_iter<const R: usize>(shape: [usize; R]) -> impl Iterator<Item = [usize; R]> {
    let element_count = shape.iter().product::<usize>();
    (0..element_count).map(move |linear| {
        let mut remaining = linear;
        let mut coords = [0usize; R];
        // Peel coordinates off from the innermost axis outwards.
        for (coord, &extent) in coords.iter_mut().zip(shape.iter()).rev() {
            *coord = remaining % extent;
            remaining /= extent;
        }
        coords
    })
}
/// Compares two tensors element by element using the caller-supplied `eq`
/// predicate.
///
/// Elements are visited in the order produced by [`index_iter`]; the first
/// pair for which `eq` returns `false` is reported as an
/// [`ItemMismatchError`] with `a`'s element as the expected value and `b`'s as
/// the actual value.
///
/// # Panics
///
/// Panics if the shapes differ or if either tensor's slice cannot be
/// obtained.
pub async fn eq_with<const R: usize, T: DataType + SimdElement>(
    a: &Tensor<R, T>,
    b: &Tensor<R, T>,
    eq: impl Fn(T, T) -> bool,
) -> Result<(), ItemMismatchError> {
    assert_eq!(a.shape(), b.shape(), "shape mismatch");
    let dims = a.shape();
    let lhs = a.as_slice().await.unwrap();
    let rhs = b.as_slice().await.unwrap();
    let first_mismatch = index_iter(dims).find_map(|position| {
        let (expected, actual) = (lhs[position], rhs[position]);
        (!eq(expected, actual)).then(|| (position, expected, actual))
    });
    match first_mismatch {
        None => Ok(()),
        Some((position, expected, actual)) => Err(ItemMismatchError::new(
            a.device(),
            position,
            format!("{:?}", expected),
            format!("{:?}", actual),
        )),
    }
}
/// Checks that every pair of elements differs by at most `tol` in absolute
/// terms.
///
/// Elements that compare equal are accepted immediately. For floating-point
/// types this matters for infinities: `inf - inf` is NaN, so without the
/// shortcut two identical infinite values would be reported as a mismatch.
/// NaN never equals anything, so a NaN on either side is still a mismatch.
///
/// Returns the first mismatch found by [`eq_with`], if any.
pub async fn approx_eq<const R: usize, T: Sub + PartialOrd + DataType + SimdElement>(
    a: &Tensor<R, T>,
    b: &Tensor<R, T>,
    tol: T,
) -> Result<(), ItemMismatchError> {
    eq_with(a, b, |va, vb| {
        // Exact equality (including matching infinities) needs no tolerance.
        if va == vb {
            return true;
        }
        // Subtract the smaller from the larger so unsigned types cannot underflow.
        let diff = if va > vb { va - vb } else { vb - va };
        diff <= tol
    })
    .await
}
/// Checks that two tensors are element-wise identical according to
/// `PartialEq`, returning the first mismatch found by [`eq_with`].
pub async fn exact_eq<const R: usize, T: DataType + SimdElement + PartialEq>(
    a: &Tensor<R, T>,
    b: &Tensor<R, T>,
) -> Result<(), ItemMismatchError> {
    let same = |lhs: T, rhs: T| PartialEq::eq(&lhs, &rhs);
    eq_with(a, b, same).await
}
/// Checks that every pair of `f32` elements agrees within a relative
/// tolerance.
///
/// A pair passes when `|a - b| <= rel_tol * max(|a|, |b|, f32::MIN_POSITIVE)`.
/// Values that compare equal pass immediately, so matching infinities (whose
/// difference is NaN) are not reported as mismatches. NaN on either side is
/// always a mismatch.
///
/// Returns the first mismatch found by [`eq_with`], if any.
pub async fn relative_eq<const R: usize>(
    a: &Tensor<R, f32>,
    b: &Tensor<R, f32>,
    rel_tol: f32,
) -> Result<(), ItemMismatchError> {
    eq_with(a, b, |va, vb| {
        // `inf - inf` is NaN, so equal infinities must be accepted up front.
        if va == vb {
            return true;
        }
        let diff = (va - vb).abs();
        // Clamp the scale away from zero so tiny magnitudes keep a non-zero budget.
        let scale = va.abs().max(vb.abs()).max(f32::MIN_POSITIVE);
        diff <= rel_tol * scale
    })
    .await
}
/// Checks that every pair of `f32` elements agrees within either an absolute
/// or a relative tolerance.
///
/// A pair passes when `|a - b| <= abs_tol`, or when
/// `|a - b| <= rel_tol * max(|a|, |b|, f32::MIN_POSITIVE)`. Values that
/// compare equal pass immediately, so matching infinities (whose difference
/// is NaN) are not reported as mismatches. NaN on either side is always a
/// mismatch.
///
/// Returns the first mismatch found by [`eq_with`], if any.
pub async fn approx_or_relative_eq<const R: usize>(
    a: &Tensor<R, f32>,
    b: &Tensor<R, f32>,
    abs_tol: f32,
    rel_tol: f32,
) -> Result<(), ItemMismatchError> {
    eq_with(a, b, |va, vb| {
        // `inf - inf` is NaN, so equal infinities must be accepted up front.
        if va == vb {
            return true;
        }
        let diff = (va - vb).abs();
        // Clamp the scale away from zero so tiny magnitudes keep a non-zero budget.
        let scale = va.abs().max(vb.abs()).max(f32::MIN_POSITIVE);
        diff <= abs_tol || diff <= rel_tol * scale
    })
    .await
}
/// Builds a tensor of the given `shape` on `device` whose elements are drawn
/// by calling `sample(rng)` once per element, in storage order.
pub fn random_tensor<const R: usize, T: DataType + SimdElement>(
    device: &Device,
    shape: [usize; R],
    rng: &mut StdRng,
    sample: impl Fn(&mut StdRng) -> T,
) -> Tensor<R, T> {
    let element_count = shape.iter().product::<usize>();
    let mut values: Vec<T> = Vec::with_capacity(element_count);
    for _ in 0..element_count {
        values.push(sample(rng));
    }
    Tensor::from_slice(device, shape, &values)
}
/// Builds a tensor of the given `shape` on `device` filled with
/// `0, 1, 2, ...` converted to `T` through `From<u16>`.
///
/// # Panics
///
/// Panics if the tensor has more elements than `u16` can index.
pub fn sequential_tensor<const R: usize, T: DataType + SimdElement + From<u16>>(
    device: &Device,
    shape: [usize; R],
) -> Tensor<R, T> {
    let element_count: usize = shape.iter().product();
    let mut values: Vec<T> = Vec::with_capacity(element_count);
    for position in 0..element_count {
        let narrow = u16::try_from(position).expect("sequential tensor index fits in u16");
        values.push(T::from(narrow));
    }
    Tensor::from_slice(device, shape, &values)
}
/// Describes how the extent of a single tensor axis is chosen for a fuzz run.
#[derive(Clone, Debug)]
pub enum FuzzSizeSpec {
/// Always use exactly this size.
Fixed(usize),
/// Pick one of the listed sizes; the list must not be empty when sampled.
Choices(Arc<[usize]>),
/// Pick a size from the half-open interval `start..end_exclusive`; the
/// interval must not be empty when sampled.
Range { start: usize, end_exclusive: usize },
}
impl FuzzSizeSpec {
    /// Draws one axis size from this spec using `rng`.
    ///
    /// `Fixed` consumes no randomness. `Choices` consumes one `u64` and
    /// reduces it modulo the number of choices. `Range` samples a
    /// `Uniform` distribution over `start..end_exclusive`.
    ///
    /// # Panics
    ///
    /// Panics if a `Choices` list is empty or a `Range` is empty.
    fn sample(&self, rng: &mut StdRng) -> usize {
        match *self {
            Self::Fixed(size) => size,
            Self::Choices(ref choices) => {
                let count = choices.len();
                assert!(
                    count != 0,
                    "fuzz size choice list must contain at least one size"
                );
                let pick = (rng.next_u64() as usize) % count;
                choices[pick]
            }
            Self::Range {
                start,
                end_exclusive,
            } => {
                assert!(
                    start < end_exclusive,
                    "fuzz size range must not be empty: {start}..{end_exclusive}"
                );
                let sizes =
                    Uniform::new(start, end_exclusive).expect("validated non-empty size range");
                sizes.sample(rng)
            }
        }
    }
}
impl From<usize> for FuzzSizeSpec {
fn from(value: usize) -> Self {
Self::Fixed(value)
}
}
impl<const N: usize> From<[usize; N]> for FuzzSizeSpec {
fn from(value: [usize; N]) -> Self {
Self::Choices(Arc::from(value))
}
}
impl From<Vec<usize>> for FuzzSizeSpec {
fn from(value: Vec<usize>) -> Self {
Self::Choices(Arc::from(value.into_boxed_slice()))
}
}
impl From<Box<[usize]>> for FuzzSizeSpec {
fn from(value: Box<[usize]>) -> Self {
Self::Choices(Arc::from(value))
}
}
impl From<Range<usize>> for FuzzSizeSpec {
fn from(value: Range<usize>) -> Self {
Self::Range {
start: value.start,
end_exclusive: value.end,
}
}
}
impl From<RangeInclusive<usize>> for FuzzSizeSpec {
fn from(value: RangeInclusive<usize>) -> Self {
let (start, end) = value.into_inner();
Self::Range {
start,
end_exclusive: end
.checked_add(1)
.expect("inclusive fuzz size range upper bound overflowed"),
}
}
}
/// Conversion into the per-axis size specifications of a rank-`R` fuzz shape.
pub trait IntoFuzzShape<const R: usize> {
/// Consumes `self` and returns one [`FuzzSizeSpec`] per axis.
fn into_shape_specs(self) -> [FuzzSizeSpec; R];
}
/// A plain shape: every axis has a fixed size.
impl<const R: usize> IntoFuzzShape<R> for [usize; R] {
    fn into_shape_specs(self) -> [FuzzSizeSpec; R] {
        self.map(FuzzSizeSpec::Fixed)
    }
}
/// An array of ready-made specs is used as-is, which allows mixing different
/// kinds of axis specification in one shape.
impl<const R: usize> IntoFuzzShape<R> for [FuzzSizeSpec; R] {
fn into_shape_specs(self) -> [FuzzSizeSpec; R] {
self
}
}
impl<const R: usize, const N: usize> IntoFuzzShape<R> for [[usize; N]; R] {
fn into_shape_specs(self) -> [FuzzSizeSpec; R] {
self.map(FuzzSizeSpec::from)
}
}
/// Each axis is given as a half-open range of sizes.
impl<const R: usize> IntoFuzzShape<R> for [Range<usize>; R] {
    fn into_shape_specs(self) -> [FuzzSizeSpec; R] {
        self.map(|range| FuzzSizeSpec::Range {
            start: range.start,
            end_exclusive: range.end,
        })
    }
}
/// Each axis is given as an inclusive range of sizes; conversion goes through
/// `FuzzSizeSpec`'s `From<RangeInclusive<usize>>` impl.
impl<const R: usize> IntoFuzzShape<R> for [RangeInclusive<usize>; R] {
    fn into_shape_specs(self) -> [FuzzSizeSpec; R] {
        self.map(|range| FuzzSizeSpec::from(range))
    }
}
/// Generator of random rank-`R` tensors with element type `T` for fuzz runs.
///
/// Shapes and values are driven by two independent seeds, so changing the
/// value seed does not change the shapes that are produced.
#[derive(Clone)]
pub struct FuzzGenerator<const R: usize, T: SimdElement> {
// Base seed for element values; mixed with the run index and rank per run.
value_seed: u64,
// Base seed for shapes; mixed with the run index and rank per run.
shape_seed: u64,
// Draws a single element value from the supplied RNG.
distribution: Arc<dyn Fn(&mut rand::rngs::StdRng) -> T + Send + Sync>,
// One size specification per axis.
shape_specs: [FuzzSizeSpec; R],
// Zero-sized marker for the element type.
phantom: std::marker::PhantomData<T>,
}
impl<const R: usize, T: SimdElement + DataType> FuzzGenerator<R, T> {
pub fn new(shape: impl IntoFuzzShape<R>) -> Self
where
StandardUniform: rand::distr::Distribution<T>,
{
Self::with_sampler(shape, |rng| StandardUniform.sample(rng))
}
pub fn with_sampler(
shape: impl IntoFuzzShape<R>,
sampler: impl Fn(&mut StdRng) -> T + Send + Sync + 'static,
) -> Self {
Self {
value_seed: 0,
shape_seed: 0,
distribution: Arc::new(sampler),
shape_specs: shape.into_shape_specs(),
phantom: std::marker::PhantomData,
}
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.value_seed = seed;
self
}
pub fn with_rng(mut self, mut rng: impl RngCore) -> Self {
self.value_seed = rng.next_u64();
self
}
pub fn with_shape_seed(mut self, seed: u64) -> Self {
self.shape_seed = seed;
self
}
pub fn with_distribution(
mut self,
distribution: impl Distribution<T> + Send + Sync + 'static,
) -> Self {
self.distribution = Arc::new(move |rng| distribution.sample(rng));
self
}
fn value_seed_for_run(&self, run: usize) -> u64 {
self.value_seed
^ (run as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)
^ (R as u64).wrapping_mul(0xBF58_476D_1CE4_E5B9)
}
fn shape_seed_for_run(&self, run: usize) -> u64 {
self.shape_seed
^ (run as u64).wrapping_mul(0xD6E8_FEB8_6659_FD93)
^ (R as u64).wrapping_mul(0x94D0_49BB_1331_11EB)
}
fn sample_shape(&self, rng: &mut StdRng) -> [usize; R] {
self.shape_specs
.clone()
.map(|shape_spec| shape_spec.sample(rng))
}
#[cfg(test)]
fn shape_for_run(&self, run: usize) -> [usize; R] {
let mut rng = rand::rngs::StdRng::seed_from_u64(self.shape_seed_for_run(run));
self.sample_shape(&mut rng)
}
fn generate_for_run(&self, device: &Device, run: usize) -> Tensor<R, T> {
let mut shape_rng = rand::rngs::StdRng::seed_from_u64(self.shape_seed_for_run(run));
let shape = self.sample_shape(&mut shape_rng);
let mut rng = rand::rngs::StdRng::seed_from_u64(self.value_seed_for_run(run));
let base = random_tensor(device, shape, &mut rng, &*self.distribution);
let strategy = run % 3;
match strategy {
0 => base,
1 => make_transposed(base, &mut rng, &*self.distribution),
_ => make_sliced(base, &mut rng, &*self.distribution),
}
}
}
/// Produces a tensor with the same shape as `tensor` that is the transpose
/// (over the last two axes) of a freshly sampled tensor.
///
/// Only `tensor`'s shape and device are used; the returned values are newly
/// drawn from `sample`. Tensors of rank below 2 are returned unchanged.
fn make_transposed<const R: usize, T: SimdElement + DataType + Default>(
    tensor: Tensor<R, T>,
    rng: &mut StdRng,
    sample: &dyn Fn(&mut StdRng) -> T,
) -> Tensor<R, T> {
    if R < 2 {
        return tensor;
    }
    let (outer_axis, inner_axis) = (R - 2, R - 1);
    // Sample with the last two extents swapped so that transposing restores
    // the requested shape.
    let mut swapped_shape = tensor.shape();
    swapped_shape.swap(outer_axis, inner_axis);
    let device = tensor.device();
    random_tensor(&device, swapped_shape, rng, sample)
        .transpose(outer_axis, inner_axis)
        .to_concrete()
}
/// Produces a tensor with the same shape as `tensor` that is a narrowed view
/// of a freshly sampled, slightly larger tensor.
///
/// Only `tensor`'s shape and device are used. The source tensor is one
/// element longer along the last axis and the result skips that first
/// element. Rank-0 tensors are returned unchanged.
fn make_sliced<const R: usize, T: SimdElement + DataType + Default>(
    tensor: Tensor<R, T>,
    rng: &mut StdRng,
    sample: &dyn Fn(&mut StdRng) -> T,
) -> Tensor<R, T> {
    const LEADING_PAD: usize = 1;
    if R == 0 {
        return tensor;
    }
    let last_axis = R - 1;
    let target_shape = tensor.shape();
    let visible_len = target_shape[last_axis];
    let mut padded_shape = target_shape;
    padded_shape[last_axis] = visible_len + LEADING_PAD;
    let device = tensor.device();
    random_tensor(&device, padded_shape, rng, sample)
        .narrow(last_axis, LEADING_PAD, visible_len)
        .to_concrete()
}
impl<const R: usize> FuzzGenerator<R, f32> {
pub fn with_positive(mut self) -> Self {
self.distribution =
Arc::new(move |rng| Uniform::new(0.0, 1.0).expect("0.0 < 1.0").sample(rng));
self
}
}
/// A source of one test argument: given a device and the index of the
/// current run, produces a value of type `Output`.
#[doc(hidden)]
pub trait GenerateFromDevice {
/// The type of value produced.
type Output;
/// Produces the argument for `run` on `device`.
fn generate(&mut self, device: &Device, run: usize) -> Self::Output;
}
/// Any `FnMut(&Device) -> O` closure is a generator; the run index is
/// ignored.
impl<F, O> GenerateFromDevice for F
where
    F: FnMut(&Device) -> O,
{
    type Output = O;

    fn generate(&mut self, device: &Device, _run: usize) -> O {
        self(device)
    }
}
impl<const R: usize, T: SimdElement + DataType> GenerateFromDevice for FuzzGenerator<R, T> {
type Output = Tensor<R, T>;
fn generate(&mut self, device: &Device, run: usize) -> Self::Output {
self.generate_for_run(device, run)
}
}
/// An async callable that receives its arguments packed into the tuple
/// `Args` and returns a boxed future.
#[doc(hidden)]
pub trait AsyncFnMutTuple<Args> {
/// The value the returned future resolves to.
type Output;
/// Invokes the callable with `args`; the returned future may borrow `self`
/// for `'a`.
fn call_mut<'a>(
&'a mut self,
args: Args,
) -> Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
}
// Implements `AsyncFnMutTuple<(A, B, ...)>` for every `FnMut(A, B, ...) -> Fut`
// whose future is `Send + 'static`. The argument tuple is destructured and
// forwarded positionally, and the resulting future is boxed and pinned.
macro_rules! impl_fn_mut_tuple {
($($type:ident),*) => {
impl<Fn, U, Fut, $($type),*> AsyncFnMutTuple<($($type,)*)> for Fn
where
Fn: FnMut($($type,)*) -> Fut,
Fut: std::future::Future<Output = U> + Send + 'static,
{
type Output = U;
// The type-parameter identifiers are reused as binding names.
#[allow(non_snake_case)]
fn call_mut<'a>(&'a mut self, ($($type,)*): ($($type,)*)) -> Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>> {
Box::pin((self)($($type,)*))
}
}
};
}
// Arities zero through six are supported.
impl_fn_mut_tuple!();
impl_fn_mut_tuple!(A);
impl_fn_mut_tuple!(A, B);
impl_fn_mut_tuple!(A, B, C);
impl_fn_mut_tuple!(A, B, C, D);
impl_fn_mut_tuple!(A, B, C, D, E);
impl_fn_mut_tuple!(A, B, C, D, E, F);
/// A tuple of generators that together produce the full argument tuple for
/// one run.
#[doc(hidden)]
pub trait GenTuple {
/// The tuple of generated arguments.
type Output;
/// Produces all arguments for `run` on `device`.
fn generate(&mut self, device: &Device, run: usize) -> Self::Output;
}
// Implements `GenTuple` for a tuple of `GenerateFromDevice` generators. Each
// element is asked to generate in tuple order, with the same device and run
// index, and the results are collected into a tuple of the same arity.
macro_rules! impl_gen_tuple {
($($type:ident -> $type_out:ident),*) => {
impl<$($type, $type_out),*> GenTuple for ($($type,)*)
where
$(
$type: GenerateFromDevice<Output = $type_out>,
)*
{
type Output = ($($type_out,)*);
// The type-parameter identifiers are reused as binding names.
#[allow(non_snake_case)]
fn generate(&mut self, device: &Device, run: usize) -> Self::Output {
let ($($type,)*) = self;
(
$(
$type.generate(device, run),
)*
)
}
}
};
}
// Arities one through six are supported.
impl_gen_tuple!(A -> AOut);
impl_gen_tuple!(A -> AOut, B -> BOut);
impl_gen_tuple!(A -> AOut, B -> BOut, C -> COut);
impl_gen_tuple!(A -> AOut, B -> BOut, C -> COut, D -> DOut);
impl_gen_tuple!(A -> AOut, B -> BOut, C -> COut, D -> DOut, E -> EOut);
impl_gen_tuple!(A -> AOut, B -> BOut, C -> COut, D -> DOut, E -> EOut, F -> FOut);
/// A tuple of tensors that can be read back into host-side values.
#[doc(hidden)]
pub trait ResolveTensorTuple {
/// The tuple of resolved values, one per tensor.
type Output;
/// Reads back every tensor in the tuple, failing with the first error
/// encountered.
fn resolve(self) -> impl Future<Output = Result<Self::Output, fusor::Error>> + Send + 'static;
/// Returns the device of a tensor in the tuple.
fn extract_device(&self) -> Device;
}
// Implements `ResolveTensorTuple` for tuples of tensors, each with its own
// rank and element type.
macro_rules! impl_resolve_tensor_tuple {
($($type:ident = $rank:ident),*) => {
impl<$(const $rank: usize,)* $($type: DataType + SimdElement),*> ResolveTensorTuple for ($(Tensor<$rank, $type>,)*)
where
$(
fusor_types::TensorSlice<$rank, $type, fusor::EitherMappedBuffer>: fusor::ToVec<Output: Send + 'static>,
)*
{
type Output = ($(<fusor_types::TensorSlice<$rank, $type, fusor::EitherMappedBuffer> as fusor::ToVec>::Output,)*);
// The type-parameter identifiers are reused as binding names.
#[allow(non_snake_case)]
async fn resolve(self) -> Result<Self::Output, fusor::Error> {
let ($($type,)*) = self;
// Tensors are awaited one after another in tuple order; `?` returns the
// first error.
Ok((
$(
fusor::ToVec::to_vec(&$type.as_slice().await?),
)*
))
}
// The repetition below expands to one `return` per tensor, so everything
// after the first is dead code.
#[allow(unused)]
#[allow(non_snake_case)]
fn extract_device(&self) -> Device {
let ($($type,)*) = self;
// Only the first tensor's device is ever returned.
$(
let device = $type.device();
return device;
)*
// Reached only for an empty tuple, which the invocations below never
// generate.
unreachable!()
}
}
};
}
// Arities one through six are supported.
impl_resolve_tensor_tuple!(A = N1);
impl_resolve_tensor_tuple!(A = N1, B = N2);
impl_resolve_tensor_tuple!(A = N1, B = N2, C = N3);
impl_resolve_tensor_tuple!(A = N1, B = N2, C = N3, D = N4);
impl_resolve_tensor_tuple!(A = N1, B = N2, C = N3, D = N4, E = N5);
impl_resolve_tensor_tuple!(A = N1, B = N2, C = N3, D = N4, E = N5, F = N6);
/// Appends one element to the end of a tuple, yielding a tuple that is one
/// element longer.
#[doc(hidden)]
pub trait PushTuple<Tail> {
    /// The tuple type with `Tail` appended.
    type Output;
    /// Returns `self` extended with `new_last` as its final element.
    fn push(self, new_last: Tail) -> Self::Output;
}

/// Splits a tuple into its first element and the tuple of the remaining
/// elements.
#[doc(hidden)]
pub trait PopTuple {
    /// The type of the first element (`()` for the empty tuple).
    type First;
    /// The tuple of everything after the first element.
    type Rest;
    /// Returns `(first, rest)`.
    fn pop(self) -> (Self::First, Self::Rest);
}

/// Pushing onto the empty tuple produces a one-element tuple.
impl<Tail> PushTuple<Tail> for () {
    type Output = (Tail,);

    fn push(self, only: Tail) -> Self::Output {
        (only,)
    }
}

/// Popping the empty tuple yields two empty tuples.
impl PopTuple for () {
    type First = ();
    type Rest = ();

    fn pop(self) -> (Self::First, Self::Rest) {
        (self, ())
    }
}

// Implements `PopTuple` and `PushTuple` for a non-empty tuple whose element
// types are the given identifiers.
macro_rules! impl_push_pop_tuple {
    ($head:ident $(, $tail:ident)*) => {
        impl<$head $(, $tail)*> PopTuple for ($head, $($tail,)*) {
            type First = $head;
            type Rest = ($($tail,)*);

            // The type-parameter identifiers are reused as binding names.
            #[allow(non_snake_case)]
            fn pop(self) -> (Self::First, Self::Rest) {
                let ($head, $($tail,)*) = self;
                ($head, ($($tail,)*))
            }
        }

        impl<$head $(, $tail)*, Appended> PushTuple<Appended> for ($head, $($tail,)*) {
            type Output = ($head, $($tail,)* Appended);

            // The type-parameter identifiers are reused as binding names.
            #[allow(non_snake_case)]
            fn push(self, new_last: Appended) -> Self::Output {
                let ($head, $($tail,)*) = self;
                ($head, $($tail,)* new_last)
            }
        }
    };
}

// Arities one through six are supported (zero is handled by the unit impls).
impl_push_pop_tuple!(A);
impl_push_pop_tuple!(A, B);
impl_push_pop_tuple!(A, B, C);
impl_push_pop_tuple!(A, B, C, D);
impl_push_pop_tuple!(A, B, C, D, E);
impl_push_pop_tuple!(A, B, C, D, E, F);
/// Error describing the first element at which two compared tensors
/// disagreed.
#[derive(Error, Debug)]
pub struct ItemMismatchError {
// Device reported for the comparison.
device: Device,
// Multi-dimensional index of the mismatching element; empty for a scalar.
position: Vec<usize>,
// Textual rendering of the expected (baseline) value.
expected: String,
// Textual rendering of the value that was actually produced.
actual: String,
}
impl ItemMismatchError {
pub fn new(
device: Device,
position: impl IntoIterator<Item = usize>,
expected: impl ToString,
actual: impl ToString,
) -> Self {
Self {
device,
position: position.into_iter().collect(),
expected: expected.to_string(),
actual: actual.to_string(),
}
}
}
/// Renders the device, the index (or `<scalar>` when the index is empty) and
/// both values on a single line.
impl Display for ItemMismatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let position = match self.position.as_slice() {
            [] => "<scalar>".to_owned(),
            coords => format!("{:?}", coords),
        };
        write!(
            f,
            "Item mismatch on device {:?} at {}: expected {}, got {}",
            self.device, position, self.expected, self.actual
        )
    }
}
/// Builder for a conformance assertion: a baseline operation, any number of
/// alternative implementations that must agree with it, the generators that
/// produce the arguments, and the comparison used on the results.
///
/// `T` is the argument tuple and `U` the result type shared by all
/// operations. Awaiting the builder runs the assertion.
pub struct AssertBuilder<T, U, Generators = (), Compare = ()> {
// Operation whose result is treated as the expected value.
baseline: Box<dyn AsyncFnMutTuple<T, Output = U>>,
// Operations whose results are compared against the baseline.
to_validate: Vec<Box<dyn AsyncFnMutTuple<T, Output = U>>>,
// Tuple of argument generators, grown by `arg`.
generators: Generators,
// Comparison strategy; `()` selects a default for supported tensor types.
compare: Compare,
// Explicit device list; `None` means "use `available_devices()`".
devices: Option<Vec<Device>>,
// Number of runs to execute.
runs: usize,
}
impl<T, U> AssertBuilder<T, U> {
    /// Starts a builder with `op` as the baseline, no generators, the default
    /// comparison, automatic device selection and five runs.
    fn new(op: impl AsyncFnMutTuple<T, Output = U> + 'static) -> Self {
        const DEFAULT_RUNS: usize = 5;
        let baseline: Box<dyn AsyncFnMutTuple<T, Output = U>> = Box::new(op);
        Self {
            baseline,
            to_validate: Vec::new(),
            generators: (),
            compare: (),
            devices: None,
            runs: DEFAULT_RUNS,
        }
    }
}
impl<T, U, Generators, Compare> AssertBuilder<T, U, Generators, Compare> {
pub fn arg<Gen, O>(self, g: Gen) -> AssertBuilder<T, U, Generators::Output, Compare>
where
Generators: PushTuple<Gen>,
Gen: GenerateFromDevice<Output = O>,
{
AssertBuilder {
baseline: self.baseline,
to_validate: self.to_validate,
generators: self.generators.push(g),
compare: self.compare,
devices: self.devices,
runs: self.runs,
}
}
pub fn compare_with<Cmp>(self, cmp: Cmp) -> AssertBuilder<T, U, Generators, Cmp>
where
Cmp: IntoCompare<U>,
{
AssertBuilder {
baseline: self.baseline,
to_validate: self.to_validate,
generators: self.generators,
compare: cmp,
devices: self.devices,
runs: self.runs,
}
}
pub fn runs(mut self, runs: usize) -> Self {
self.runs = runs;
self
}
pub fn devices(mut self, devices: impl IntoIterator<Item = Device>) -> Self {
self.devices = Some(devices.into_iter().collect());
self
}
pub fn equal_to(mut self, other: impl AsyncFnMutTuple<T, Output = U> + 'static) -> Self {
self.to_validate.push(Box::new(other));
self
}
pub fn equal_to_resolved_op(
self,
mut other: impl AsyncFnMutTuple<T::Output, Output = U> + Copy + Send + 'static,
) -> Self
where
T: ResolveTensorTuple,
{
struct UnpackedTuple<T>(T);
impl<F, Fut, I, O> AsyncFnMutTuple<I> for UnpackedTuple<F>
where
F: FnMut(I) -> Fut,
Fut: std::future::Future<Output = O> + Send + 'static,
{
type Output = O;
fn call_mut<'a>(
&'a mut self,
input: I,
) -> Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>> {
Box::pin((self.0)(input))
}
}
let wrapped = move |input: T| {
let input = input.resolve();
async move {
let input = input.await.unwrap();
other.call_mut(input).await
}
};
self.equal_to(UnpackedTuple(wrapped))
}
pub fn equal_to_resolved_with_device(
self,
mut other: impl AsyncFnMutTuple<<T::Output as PushTuple<Device>>::Output, Output = U>
+ Copy
+ Send
+ 'static,
) -> Self
where
T: ResolveTensorTuple,
T::Output: PushTuple<Device>,
{
struct UnpackedTuple<T>(T);
impl<F, Fut, I, O> AsyncFnMutTuple<I> for UnpackedTuple<F>
where
F: FnMut(I) -> Fut,
Fut: std::future::Future<Output = O> + Send + 'static,
{
type Output = O;
fn call_mut<'a>(
&'a mut self,
input: I,
) -> Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>> {
Box::pin((self.0)(input))
}
}
let wrapped = move |input: T| {
let device = input.extract_device();
let input = input.resolve();
async move {
let input = input.await.unwrap();
other.call_mut(input.push(device)).await
}
};
self.equal_to(UnpackedTuple(wrapped))
}
pub fn equal_to_array_op<const R: usize, D, A>(
self,
mut other: impl AsyncFnMutTuple<T::Output, Output = A> + Copy + Send + 'static,
) -> Self
where
T: ResolveTensorTuple,
for<'a> U: FromArray<R, D, &'a A, Device>,
{
struct UnpackedTuple<T>(T);
impl<F, Fut, I, O> AsyncFnMutTuple<I> for UnpackedTuple<F>
where
F: FnMut(I) -> Fut,
Fut: std::future::Future<Output = O> + Send + 'static,
{
type Output = O;
fn call_mut<'a>(
&'a mut self,
input: I,
) -> Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>> {
Box::pin((self.0)(input))
}
}
let wrapped = move |input: T| {
let device = input.extract_device();
let input = input.resolve();
async move {
let input = input.await.unwrap();
let output = other.call_mut(input).await;
U::from_array(&output, &device)
}
};
self.equal_to(UnpackedTuple(wrapped))
}
}
impl<T, U, Generators, Compare> IntoFuture for AssertBuilder<T, U, Generators, Compare>
where
Generators: GenTuple<Output = T> + 'static,
Compare: IntoCompare<U>,
T: Clone + 'static,
U: Clone + 'static,
{
type Output = Result<(), Compare::Error>;
type IntoFuture = Pin<Box<dyn std::future::Future<Output = Self::Output>>>;
fn into_future(mut self) -> Self::IntoFuture {
let compare_fn = self.compare.into_compare();
let future = async move {
let devices = if let Some(devs) = self.devices {
devs
} else {
available_devices().await
};
for run in 0..self.runs {
for device in &devices {
let args = self.generators.generate(device, run);
let expected = self.baseline.call_mut(args.clone()).await;
for to_validate in &mut self.to_validate {
let actual = to_validate.call_mut(args.clone()).await;
compare_fn(&expected, &actual).await?;
}
}
}
Ok(())
};
Box::pin(future)
}
}
/// Boxed future returned by a comparison function; resolves to `Ok(())` when
/// the two values are considered equal and to `Err(E)` otherwise.
pub type CompareFut<'a, E> = Pin<Box<dyn std::future::Future<Output = Result<(), E>> + 'a>>;
/// Conversion into a comparison function over two values of type `U`.
#[doc(hidden)]
pub trait IntoCompare<U> {
/// Error produced when the comparison fails.
type Error: Error;
/// Consumes `self` and returns the comparison function.
fn into_compare(self)
-> impl for<'a> Fn(&'a U, &'a U) -> CompareFut<'a, Self::Error> + 'static;
}
/// Any suitable `'static` comparison closure is used directly as the
/// comparison.
impl<U, Cmp, E: Error> IntoCompare<U> for Cmp
where
Cmp: for<'a> Fn(&'a U, &'a U) -> CompareFut<'a, E> + 'static,
{
type Error = E;
fn into_compare(
self,
) -> impl for<'a> Fn(&'a U, &'a U) -> CompareFut<'a, Self::Error> + 'static {
self
}
}
/// Default comparison for `u32` tensors: exact element-wise equality.
impl<const R: usize> IntoCompare<Tensor<R, u32>> for () {
type Error = ItemMismatchError;
fn into_compare(
self,
) -> impl for<'a> Fn(&'a Tensor<R, u32>, &'a Tensor<R, u32>) -> CompareFut<'a, Self::Error> + 'static
{
|a, b| Box::pin(exact_eq(a, b))
}
}
/// Default comparison for `f32` tensors: absolute tolerance of `1e-5`.
impl<const R: usize> IntoCompare<Tensor<R, f32>> for () {
type Error = ItemMismatchError;
fn into_compare(
self,
) -> impl for<'a> Fn(&'a Tensor<R, f32>, &'a Tensor<R, f32>) -> CompareFut<'a, Self::Error> + 'static
{
|a, b| Box::pin(approx_eq(a, b, 1e-5))
}
}
/// Default comparison for `f16` tensors: absolute tolerance of `1e-3`
/// (converted to `f16`).
impl<const R: usize> IntoCompare<Tensor<R, f16>> for () {
type Error = ItemMismatchError;
fn into_compare(
self,
) -> impl for<'a> Fn(&'a Tensor<R, f16>, &'a Tensor<R, f16>) -> CompareFut<'a, Self::Error> + 'static
{
|a, b| Box::pin(approx_eq(a, b, f16::from_f32(1e-3)))
}
}
/// Returns a comparison function that delegates to [`exact_eq`].
pub fn exact_compare<const R: usize, T>()
-> impl for<'a> Fn(&'a Tensor<R, T>, &'a Tensor<R, T>) -> CompareFut<'a, ItemMismatchError> + Clone
where
T: DataType + SimdElement + PartialEq,
{
|a, b| Box::pin(exact_eq(a, b))
}
/// Returns a comparison function that delegates to [`approx_eq`] with the
/// absolute tolerance `tol`.
pub fn approx_compare<const R: usize, T>(
tol: T,
) -> impl for<'a> Fn(&'a Tensor<R, T>, &'a Tensor<R, T>) -> CompareFut<'a, ItemMismatchError> + Clone
where
T: Sub<Output = T> + PartialOrd + DataType + SimdElement + Copy,
{
move |a, b| Box::pin(approx_eq(a, b, tol))
}
/// Returns a comparison function that delegates to [`relative_eq`] with the
/// relative tolerance `rel_tol`.
pub fn relative_compare<const R: usize>(
rel_tol: f32,
) -> impl for<'a> Fn(&'a Tensor<R, f32>, &'a Tensor<R, f32>) -> CompareFut<'a, ItemMismatchError> + Clone
{
move |a, b| Box::pin(relative_eq(a, b, rel_tol))
}
/// Returns a comparison function that delegates to
/// [`approx_or_relative_eq`] with the given absolute and relative
/// tolerances.
pub fn approx_or_relative_compare<const R: usize>(
abs_tol: f32,
rel_tol: f32,
) -> impl for<'a> Fn(&'a Tensor<R, f32>, &'a Tensor<R, f32>) -> CompareFut<'a, ItemMismatchError> + Clone
{
move |a, b| Box::pin(approx_or_relative_eq(a, b, abs_tol, rel_tol))
}
/// Entry point of the assertion DSL: starts an [`AssertBuilder`] with `op` as
/// the baseline operation.
pub fn assert<T, U>(op: impl AsyncFnMutTuple<T, Output = U> + 'static) -> AssertBuilder<T, U> {
    let builder: AssertBuilder<T, U> = AssertBuilder::new(op);
    builder
}
// Tests for the public builder API and for shape sampling in `FuzzGenerator`.
#[cfg(test)]
mod api_tests {
use fusor::{Device, Tensor};
use crate::{FuzzGenerator, FuzzSizeSpec};
// End-to-end use of the builder: the tensor `sin` is checked against a
// reference computed on the resolved host-side vector.
#[tokio::test]
async fn test_api() {
crate::assert(async |x: fusor::Tensor<1, f32>| x.sin().to_concrete())
.arg(FuzzGenerator::<1, f32>::new([63..=65]))
.equal_to_resolved_with_device(async |vec: Vec<f32>, device: Device| {
let expected = vec.iter().map(|&v| v.sin()).collect::<Vec<_>>();
Tensor::new(&device, &expected)
})
.runs(10)
.await
.unwrap();
}
// Every sampled axis size must come from the configured choices or range,
// whichever form of shape specification is used.
#[test]
fn fuzz_generator_accepts_size_choices_and_ranges() {
let choice_generator =
FuzzGenerator::<2, f32>::new([[255, 256, 257], [31, 32, 33]]).with_seed(1234);
for run in 0..24 {
let shape = choice_generator.shape_for_run(run);
assert!([255, 256, 257].contains(&shape[0]));
assert!([31, 32, 33].contains(&shape[1]));
}
let range_generator = FuzzGenerator::<2, f32>::new([255..=257, 31..=33]).with_seed(5678);
for run in 0..24 {
let shape = range_generator.shape_for_run(run);
assert!((255..=257).contains(&shape[0]));
assert!((31..=33).contains(&shape[1]));
}
// Mixed specs: choices on the first axis, a range on the second.
let mixed_generator = FuzzGenerator::<2, f32>::new([
FuzzSizeSpec::from([255, 256, 257]),
FuzzSizeSpec::from(63..=65),
])
.with_seed(9012);
for run in 0..24 {
let shape = mixed_generator.shape_for_run(run);
assert!([255, 256, 257].contains(&shape[0]));
assert!((63..=65).contains(&shape[1]));
}
}
// `with_seed` only changes the value seed, so two generators that differ only
// in that seed must produce identical shapes for every run.
#[test]
fn fuzz_generator_shapes_do_not_depend_on_value_seed() {
let first = FuzzGenerator::<2, f32>::new([255..=257, 63..=65]).with_seed(1);
let second = FuzzGenerator::<2, f32>::new([255..=257, 63..=65]).with_seed(2);
for run in 0..24 {
assert_eq!(first.shape_for_run(run), second.shape_for_run(run));
}
}
}