use std::simd::prelude::*;
use std::simd::Simd;
use crate::{Reg, Reg32, UReg};
/// Per-channel statistics of one 8-pixel half of a 4x4 block, as produced by
/// [`prepare_averages`].
///
/// Both arrays are indexed by channel (0, 1, 2); all arithmetic is lane-wise.
#[derive(Clone, Copy, Default)]
pub struct SubblockStats {
    /// Sum of the channel over the half's 8 pixels.
    pub sum: [Reg; 3],
    /// Mean of the channel rounded half-up: `(sum + 4) >> 3`.
    pub avg: [Reg; 3],
}
/// Divides every lane by 255, rounding to the nearest integer.
///
/// Exact for `0 <= x <= 255 * 127` with 16-bit signed lanes: no intermediate
/// value exceeds `i16::MAX` in that range. Ties cannot occur because 255 is
/// odd. Callers in this module pass per-channel averages multiplied by 15 or
/// 31; presumably those averages are 8-bit values (not checked here).
#[inline]
pub fn fast_div_255_round(x: Reg) -> Reg {
    // Bias by half a step, then divide by 255 using `r + (r >> 8)` to make up
    // for shifting by 256 instead. This yields `floor((x + 127) / 255)`, i.e.
    // round-to-nearest. Note that `floor((x + 128) / 255)` would differ: it
    // rounds `x ≡ 127 (mod 255)` up.
    let r = x + Simd::splat(128);
    (r + (r >> 8)) >> 8
}
#[inline]
pub fn prepare_averages(data: &[[[Reg; 3]; 4]; 4]) -> [SubblockStats; 4] {
let mut sum_2x2 = [[Reg::default(); 3]; 4];
for y in 0..2 {
for x in 0..2 {
for ch in 0..3 {
sum_2x2[y * 2 + x][ch] = data[y * 2][x * 2][ch]
+ data[y * 2 + 1][x * 2][ch]
+ data[y * 2][x * 2 + 1][ch]
+ data[y * 2 + 1][x * 2 + 1][ch];
}
}
}
let mut out = [SubblockStats::default(); 4];
for (i, (s0, s1)) in [(0, 1), (2, 3), (0, 2), (1, 3)].into_iter().enumerate() {
for ch in 0..3 {
out[i].sum[ch] = sum_2x2[s0][ch] + sum_2x2[s1][ch];
out[i].avg[ch] = (out[i].sum[ch] + Simd::splat(4)) >> 3;
}
}
out
}
/// One candidate base-colour encoding of a subblock pair, with its error score.
struct QuantResultWithErr {
    /// Header bits: flip at bit 0, diff at bit 1, blue fields at bits 8..16.
    /// Bits 2..8 are left zero by the encoders in this module.
    lo: UReg,
    /// Header bits holding the green (bits 0..8) and red (bits 8..16) fields.
    hi: UReg,
    /// First subblock's base colour, expanded back to 8 bits per channel.
    scaled0: [Reg; 3],
    /// Second subblock's base colour, expanded back to 8 bits per channel.
    scaled1: [Reg; 3],
    /// Score from `eval_quant_err`; lower is better.
    err: Reg32,
}
/// Base-colour encoding chosen per lane by [`quantize_averages`].
pub struct QuantResult {
    /// Header bits: flip at bit 0, diff at bit 1, blue fields at bits 8..16.
    /// Bits 2..8 are zero here (presumably filled in by a later stage).
    pub lo: UReg,
    /// Header bits holding the green (bits 0..8) and red (bits 8..16) fields.
    pub hi: UReg,
    /// First subblock's base colour, expanded to 8 bits per channel.
    pub scaled0: [Reg; 3],
    /// Second subblock's base colour, expanded to 8 bits per channel.
    pub scaled1: [Reg; 3],
}
/// Encodes a subblock pair in individual (4:4:4) mode.
///
/// Each average is quantized to 4 bits per channel with
/// `fast_div_255_round(v * 15)`. The result is expanded back to 8 bits by
/// nibble replication and scored with `eval_quant_err`. Both colours are then
/// packed into `lo`/`hi` with the diff bit clear.
#[inline]
fn quant_444(
    avg0: [Reg; 3],
    avg1: [Reg; 3],
    sum0: [Reg; 3],
    sum1: [Reg; 3],
    flip: bool,
) -> QuantResultWithErr {
    let to4 = |avg: [Reg; 3]| avg.map(|v| fast_div_255_round(v * Simd::splat(15)));
    // 4 -> 8 bits: 0xA becomes 0xAA.
    let expand = |q: [Reg; 3]| q.map(|v| (v << 4) | v);
    let bits = |v: Reg| v.cast::<u16>();

    let (first, second) = (to4(avg0), to4(avg1));
    let (scaled0, scaled1) = (expand(first), expand(second));
    let err = eval_quant_err(scaled0, scaled1, sum0, sum1);

    // Diff bit (bit 1) stays zero in individual mode.
    let flip_bit = UReg::splat(u16::from(flip));
    let lo = (bits(first[2]) << 12) | (bits(second[2]) << 8) | flip_bit;
    let hi = (bits(first[0]) << 12)
        | (bits(second[0]) << 8)
        | (bits(first[1]) << 4)
        | bits(second[1]);

    QuantResultWithErr { lo, hi, scaled0, scaled1, err }
}
/// Encodes a subblock pair in differential (5:5:5 base + 3-bit delta) mode.
///
/// Both averages are quantized to 5 bits per channel with
/// `fast_div_255_round(v * 31)`. Only the first is stored directly; the second
/// is stored as an offset clamped to the signed 3-bit range `-4..=3`. The error
/// is scored against the colour that offset actually reaches. The diff bit
/// (bit 1 of `lo`) is set.
#[inline]
fn quant_555(
    avg0: [Reg; 3],
    avg1: [Reg; 3],
    sum0: [Reg; 3],
    sum1: [Reg; 3],
    flip: bool,
) -> QuantResultWithErr {
    let to5 = |avg: [Reg; 3]| avg.map(|v| fast_div_255_round(v * Simd::splat(31)));
    // 5 -> 8 bits: shift up and refill the low 3 bits with the top 3 bits.
    let expand = |q: [Reg; 3]| q.map(|v| (v << 3) | (v >> 2));
    let bits = |v: Reg| v.cast::<u16>();
    // Keep the low 3 bits, giving a two's-complement delta field.
    let delta_field = |d: Reg| bits(d) & Simd::splat(0b111);

    let base = to5(avg0);
    let wanted = to5(avg1);
    let delta = [0, 1, 2].map(|ch| {
        (wanted[ch] - base[ch]).simd_clamp(Simd::splat(-4), Simd::splat(3))
    });
    // The colour the decoder will reconstruct once the delta has been clamped.
    let reachable = [0, 1, 2].map(|ch| base[ch] + delta[ch]);

    let scaled0 = expand(base);
    let scaled1 = expand(reachable);
    let err = eval_quant_err(scaled0, scaled1, sum0, sum1);

    let flip_bit = UReg::splat(u16::from(flip));
    let diff_bit = UReg::splat(1 << 1);
    let lo = (bits(base[2]) << 11) | (delta_field(delta[2]) << 8) | diff_bit | flip_bit;
    let hi = (bits(base[0]) << 11)
        | (delta_field(delta[0]) << 8)
        | (bits(base[1]) << 3)
        | delta_field(delta[1]);

    QuantResultWithErr { lo, hi, scaled0, scaled1, err }
}
/// Scores a pair of expanded base colours against their subblock sums.
///
/// Returns, per lane, the sum over channels of `c * (4c - S)` for both
/// `(q0, sum0)` and `(q1, sum1)`, computed in 32-bit lanes.
///
/// If each `S` is the total of the 8 pixels the colour `c` stands in for (as
/// [`prepare_averages`] produces), then `c * (4c - S)` equals
/// `(Σ(p - c)² - Σp²) / 2`. Candidates covering the same pixels share the
/// `Σp²` term, so a lower value means a lower squared error.
#[inline]
fn eval_quant_err(q0: [Reg; 3], q1: [Reg; 3], sum0: [Reg; 3], sum1: [Reg; 3]) -> Reg32 {
    let mut total = Reg32::splat(0);
    for ch in 0..3 {
        for (color, sum) in [(q0[ch], sum0[ch]), (q1[ch], sum1[ch])] {
            let c = color.cast::<i32>();
            total += c * ((c << 2) - sum.cast::<i32>());
        }
    }
    total
}
/// Encodes a subblock pair in both 4:4:4 and 5:5:5 modes and keeps, per lane,
/// the one with the strictly lower error. Ties keep the 4:4:4 result.
#[inline]
fn quantize_endpoint_pairs(
    avg0: [Reg; 3],
    avg1: [Reg; 3],
    sum0: [Reg; 3],
    sum1: [Reg; 3],
    flip: bool,
) -> QuantResultWithErr {
    let individual = quant_444(avg0, avg1, sum0, sum1, flip);
    let differential = quant_555(avg0, avg1, sum0, sum1, flip);

    // The error comparison yields a 32-bit mask; the colour fields are
    // 16-bit, so they need a narrowed copy of it.
    let use_diff_wide = differential.err.simd_lt(individual.err);
    let use_diff = use_diff_wide.cast::<i16>();
    let pick = |a: [Reg; 3], b: [Reg; 3]| [0, 1, 2].map(|ch| use_diff.select(a[ch], b[ch]));

    QuantResultWithErr {
        lo: use_diff.select(differential.lo, individual.lo),
        hi: use_diff.select(differential.hi, individual.hi),
        scaled0: pick(differential.scaled0, individual.scaled0),
        scaled1: pick(differential.scaled1, individual.scaled1),
        err: use_diff_wide.select(differential.err, individual.err),
    }
}
/// Chooses, per lane, the base-colour encoding for a 4x4 block.
///
/// `data[y][x][ch]` is channel `ch` of pixel `(y, x)`, using the indexing of
/// [`prepare_averages`]. Two layouts are tried:
/// - the halves split along `y`, encoded with `flip = true`;
/// - the halves split along `x`, encoded with `flip = false`.
///
/// Each layout first picks its better 4:4:4 / 5:5:5 encoding. The layout with
/// the strictly lower error then wins; ties keep `flip = false`. Both layouts
/// cover the same 16 pixels, so their `eval_quant_err` scores are comparable.
#[inline]
pub fn quantize_averages(data: &[[[Reg; 3]; 4]; 4]) -> QuantResult {
    let [y_lo, y_hi, x_lo, x_hi] = prepare_averages(data);
    let flip = quantize_endpoint_pairs(y_lo.avg, y_hi.avg, y_lo.sum, y_hi.sum, true);
    let no_flip = quantize_endpoint_pairs(x_lo.avg, x_hi.avg, x_lo.sum, x_hi.sum, false);
    // Narrow the 32-bit comparison mask to select among 16-bit fields.
    let prefer_flip = flip.err.simd_lt(no_flip.err).cast::<i16>();
    QuantResult {
        lo: prefer_flip.select(flip.lo, no_flip.lo),
        hi: prefer_flip.select(flip.hi, no_flip.hi),
        scaled0: [0, 1, 2].map(|i| prefer_flip.select(flip.scaled0[i], no_flip.scaled0[i])),
        scaled1: [0, 1, 2].map(|i| prefer_flip.select(flip.scaled1[i], no_flip.scaled1[i])),
    }
}