mod common;
use common::{reshape4, rope_interleaved_4d, rope_normal_4d};
use fusor::{Device, RopeCache, Tensor, ToVec1, base_inverse_frequency};
use fusor_conformance::{FuzzGenerator, approx_compare};
use rand::distr::Uniform;
fn rope_tables(
head_dim: usize,
context_length: usize,
theta: f32,
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
let inv_freq = base_inverse_frequency(head_dim, theta);
let cos = (0..context_length)
.map(|position| {
inv_freq
.iter()
.map(|freq| (position as f32 * freq).cos())
.collect()
})
.collect();
let sin = (0..context_length)
.map(|position| {
inv_freq
.iter()
.map(|freq| (position as f32 * freq).sin())
.collect()
})
.collect();
(cos, sin)
}
static COS: &[[f32; 2]] = &[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
static SIN: &[[f32; 2]] = &[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];
fn cos_vec() -> Vec<Vec<f32>> {
COS.iter().map(|r| r.to_vec()).collect()
}
fn sin_vec() -> Vec<Vec<f32>> {
SIN.iter().map(|r| r.to_vec()).collect()
}
#[tokio::test]
async fn rope_and_cache_paths_match_reference_variants() {
let expected = vec![
1.0,
1.0 / 10_000.0f32.powf(2.0 / 8.0),
1.0 / 10_000.0f32.powf(4.0 / 8.0),
1.0 / 10_000.0f32.powf(6.0 / 8.0),
];
assert_eq!(base_inverse_frequency(8, 10_000.0), expected);
let cos = cos_vec();
let sin = sin_vec();
let fuzz_input = FuzzGenerator::<4, f32>::new([1, 2, 3, 4])
.with_seed(600)
.with_distribution(Uniform::new(-6.0, 6.0).unwrap());
fusor_conformance::assert({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let slice = x
.to_concrete()
.flatten_all()
.to_concrete()
.as_slice()
.await
.unwrap();
let flat: Vec<f32> = slice.to_vec1();
let host = reshape4(&flat, [1, 2, 3, 4]);
Tensor::new(&device, &rope_normal_4d(&host, &cos, &sin))
}
}
})
.arg(fuzz_input.clone())
.equal_to({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let cos_t: Tensor<2, f32> = Tensor::new(&device, &cos);
let sin_t: Tensor<2, f32> = Tensor::new(&device, &sin);
x.rope(&cos_t, &sin_t)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
fusor_conformance::assert({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let slice = x
.to_concrete()
.flatten_all()
.to_concrete()
.as_slice()
.await
.unwrap();
let flat: Vec<f32> = slice.to_vec1();
let host = reshape4(&flat, [1, 2, 3, 4]);
Tensor::new(&device, &rope_interleaved_4d(&host, &cos, &sin))
}
}
})
.arg(fuzz_input.clone())
.equal_to({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let cos_t: Tensor<2, f32> = Tensor::new(&device, &cos);
let sin_t: Tensor<2, f32> = Tensor::new(&device, &sin);
x.rope_interleaved(&cos_t, &sin_t)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
fusor_conformance::assert({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let cos_t: Tensor<2, f32> = Tensor::new(&device, &cos);
let sin_t: Tensor<2, f32> = Tensor::new(&device, &sin);
x.rope_normal_fused(&cos_t, &sin_t)
}
}
})
.arg(fuzz_input.clone())
.equal_to({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let cos_t: Tensor<2, f32> = Tensor::new(&device, &cos);
let sin_t: Tensor<2, f32> = Tensor::new(&device, &sin);
x.rope(&cos_t, &sin_t)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
fusor_conformance::assert({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let cos_t: Tensor<2, f32> = Tensor::new(&device, &cos);
let sin_t: Tensor<2, f32> = Tensor::new(&device, &sin);
x.rope_fused(&cos_t, &sin_t)
}
}
})
.arg(fuzz_input.clone())
.equal_to({
let cos = cos.clone();
let sin = sin.clone();
move |x: Tensor<4, f32>| {
let cos = cos.clone();
let sin = sin.clone();
async move {
let device = x.device();
let cos_t: Tensor<2, f32> = Tensor::new(&device, &cos);
let sin_t: Tensor<2, f32> = Tensor::new(&device, &sin);
x.rope_interleaved(&cos_t, &sin_t)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
let cache_cos = vec![
vec![1.0f32, 1.1],
vec![1.2, 1.3],
vec![1.4, 1.5],
vec![1.6, 1.7],
];
let cache_sin = vec![
vec![0.1f32, 0.2],
vec![0.3, 0.4],
vec![0.5, 0.6],
vec![0.7, 0.8],
];
let gen_q = FuzzGenerator::<4, f32>::new([1, 2, 3, 4])
.with_seed(601)
.with_distribution(Uniform::new(-6.0, 6.0).unwrap());
let gen_k = FuzzGenerator::<4, f32>::new([1, 2, 3, 4])
.with_seed(602)
.with_distribution(Uniform::new(-6.0, 6.0).unwrap());
fusor_conformance::assert({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |q: Tensor<4, f32>, k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = q.device();
let cache = RopeCache::from_parts(
Tensor::new(&device, &cache_cos),
Tensor::new(&device, &cache_sin),
);
cache.forward(&q, &k, 1).0
}
}
})
.arg(gen_q.clone())
.arg(gen_k.clone())
.equal_to({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |q: Tensor<4, f32>, _k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = q.device();
let cos_slice: Tensor<2, f32> = Tensor::new(&device, &cache_cos[1..4].to_vec());
let sin_slice: Tensor<2, f32> = Tensor::new(&device, &cache_sin[1..4].to_vec());
q.rope_normal_fused(&cos_slice, &sin_slice)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
fusor_conformance::assert({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |q: Tensor<4, f32>, k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = q.device();
let cache = RopeCache::from_parts(
Tensor::new(&device, &cache_cos),
Tensor::new(&device, &cache_sin),
);
cache.forward(&q, &k, 1).1
}
}
})
.arg(gen_q.clone())
.arg(gen_k.clone())
.equal_to({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |_q: Tensor<4, f32>, k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = k.device();
let cos_slice: Tensor<2, f32> = Tensor::new(&device, &cache_cos[1..4].to_vec());
let sin_slice: Tensor<2, f32> = Tensor::new(&device, &cache_sin[1..4].to_vec());
k.rope_normal_fused(&cos_slice, &sin_slice)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
fusor_conformance::assert({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |q: Tensor<4, f32>, k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = q.device();
let cache = RopeCache::from_parts(
Tensor::new(&device, &cache_cos),
Tensor::new(&device, &cache_sin),
);
cache.forward_interleaved(&q, &k, 1).0
}
}
})
.arg(gen_q.clone())
.arg(gen_k.clone())
.equal_to({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |q: Tensor<4, f32>, _k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = q.device();
let cos_slice: Tensor<2, f32> = Tensor::new(&device, &cache_cos[1..4].to_vec());
let sin_slice: Tensor<2, f32> = Tensor::new(&device, &cache_sin[1..4].to_vec());
q.rope_fused(&cos_slice, &sin_slice)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
fusor_conformance::assert({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |q: Tensor<4, f32>, k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = q.device();
let cache = RopeCache::from_parts(
Tensor::new(&device, &cache_cos),
Tensor::new(&device, &cache_sin),
);
cache.forward_interleaved(&q, &k, 1).1
}
}
})
.arg(gen_q)
.arg(gen_k)
.equal_to({
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
move |_q: Tensor<4, f32>, k: Tensor<4, f32>| {
let cache_cos = cache_cos.clone();
let cache_sin = cache_sin.clone();
async move {
let device = k.device();
let cos_slice: Tensor<2, f32> = Tensor::new(&device, &cache_cos[1..4].to_vec());
let sin_slice: Tensor<2, f32> = Tensor::new(&device, &cache_sin[1..4].to_vec());
k.rope_fused(&cos_slice, &sin_slice)
}
}
})
.compare_with(approx_compare::<4, f32>(1e-4))
.runs(3)
.await
.unwrap();
fusor_conformance::assert(async |device: Device| {
RopeCache::new(4, 3, 10_000.0, &device)
.unwrap()
.cos()
.clone()
})
.arg(|device: &Device| device.clone())
.equal_to({
let expected_cos = rope_tables(4, 3, 10_000.0).0;
move |device: Device| {
let expected_cos = expected_cos.clone();
async move { Tensor::new(&device, &expected_cos) }
}
})
.compare_with(approx_compare::<2, f32>(1e-6))
.await
.unwrap();
fusor_conformance::assert(async |device: Device| {
RopeCache::new(4, 3, 10_000.0, &device)
.unwrap()
.sin()
.clone()
})
.arg(|device: &Device| device.clone())
.equal_to({
let expected_sin = rope_tables(4, 3, 10_000.0).1;
move |device: Device| {
let expected_sin = expected_sin.clone();
async move { Tensor::new(&device, &expected_sin) }
}
})
.compare_with(approx_compare::<2, f32>(1e-6))
.await
.unwrap();
}