//! CPU tensor operations with SIMD acceleration

use std::ops::Deref;

use pulp::Simd;
use pulp::bytemuck::Pod;

// Module declarations
mod cast;
mod comparison;
mod concrete_tensor;
mod conditional;
mod elementwise;
mod expr;
mod gather;
mod index;
mod map_layout;
mod matmul;
mod pairwise;
mod parallel;
mod quantized;
mod reduce;
mod scalar;
mod slice_assign;
mod tensor;

/// Upper bound on the SIMD lane count supported by strided tensor gather operations.
///
/// The widest vector handled today is a 512-bit AVX-512 register split into 8-bit
/// lanes (`i8`), giving 512 / 8 = 64 lanes. Current architectures don't exceed this;
/// if a future target does, this is the single place to raise the limit.
pub(crate) const MAX_SIMD_LANES: usize = 512 / 8;

// Re-export public types
pub use concrete_tensor::ConcreteTensor;
pub use elementwise::{
    Abs, Acos, Acosh, Asin, Asinh, Atan, Atanh, Cos, Cosh, Exp, Exp2, Log, Log2, Neg, Sin, Sinh,
    Sqrt, Tan, Tanh,
};
pub use expr::materialize_expr;
pub use map_layout::MapLayout;
pub use pairwise::{Add, Div, Mul, Rem, Sub};
pub use quantized::{Dequantize, QuantizedTensor};
pub use scalar::{AddScalar, Broadcast, DivScalar, MulScalar, SubScalar};
pub use tensor::{FloatOps, Scalar, Tensor};

// Re-export FromArray trait from fusor-types
pub use fusor_types::FromArray;

// Re-export Layout from fusor-types for public API
pub use fusor_types::Layout;

// Re-export aligned_vec types for use by dependent crates
pub use aligned_vec::ABox;
pub use aligned_vec::AVec;

// Re-export GGUF types for convenience
pub use fusor_gguf::{
    BlockQ4_0, BlockQ4K, BlockQ5_0, BlockQ5K, BlockQ6K, BlockQ8_0, GgmlType, GgufBlock,
};

// Re-export TensorSlice from fusor-types
pub use fusor_types::TensorSlice;

/// A buffer holding CPU tensor data as bytes.
///
/// This type is the CPU equivalent of fusor-core's `MappedBuffer` for GPU tensors.
/// It owns the raw bytes of tensor data and implements `Deref<Target = [u8]>`
/// so it can back a `TensorSlice`.
#[derive(Clone)]
pub struct CpuMappedBuffer {
    bytes: Box<[u8]>,
}

impl CpuMappedBuffer {
    /// Create a new `CpuMappedBuffer` that takes ownership of `bytes` without copying.
    pub fn new(bytes: Box<[u8]>) -> Self {
        Self { bytes }
    }
}

impl From<Box<[u8]>> for CpuMappedBuffer {
    /// Equivalent to [`CpuMappedBuffer::new`].
    fn from(bytes: Box<[u8]>) -> Self {
        Self::new(bytes)
    }
}

impl From<Vec<u8>> for CpuMappedBuffer {
    /// Converts an owned byte vector.
    ///
    /// `Vec::into_boxed_slice` reallocates only if the vector has spare capacity.
    fn from(bytes: Vec<u8>) -> Self {
        Self::new(bytes.into_boxed_slice())
    }
}

// Manual `Debug`: a derived impl would print every byte of a potentially huge
// tensor buffer, so only the length is shown.
impl std::fmt::Debug for CpuMappedBuffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CpuMappedBuffer")
            .field("len", &self.bytes.len())
            .finish_non_exhaustive()
    }
}

impl Deref for CpuMappedBuffer {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        &self.bytes
    }
}

// Re-export operation traits and markers for public bounds
pub use cast::CastTo;
pub use comparison::{Eq, EqOp, Gt, GtOp, Gte, GteOp, Lt, LtOp, Lte, LteOp, Ne, NeOp};
pub use conditional::IsNonZero;
pub use elementwise::{
    AbsOp, AcosOp, AcoshOp, AsinOp, AsinhOp, AtanOp, AtanhOp, CosOp, CoshOp, Exp2Op, ExpOp, Log2Op,
    LogOp, NegOp, SimdUnaryOp, SinOp, SinhOp, SqrtOp, TanOp, TanhOp,
};
pub use matmul::MatmulImpl;
pub use pairwise::{AddOp, DivOp, MulOp, RemOp, SimdBinaryOp, SubOp};
pub use reduce::{
    MaxOp, MinOp, ProdOp, SimdReduceOp, SumOp, layer_norm_last_dim_fused, softmax_last_dim_fused,
};

// Re-export internal types used by other modules
pub(crate) use concrete_tensor::IndexIterator;

/// Maps a tensor type to the tensor type with one fewer dimension.
///
/// Used for axis reductions, where the result has rank one lower than the input.
pub trait LastRankInner {
    /// The tensor type whose rank is one less than `Self`'s.
    type LastRank;
}

/// Convenience bound: `X: LastRank<R, T>` holds exactly when `X`'s
/// one-rank-smaller type is `ConcreteTensor<T, R>`.
pub trait LastRank<const R: usize, T: SimdElement>:
    LastRankInner<LastRank = ConcreteTensor<T, R>>
{
}

impl<X, T, const R: usize> LastRank<R, T> for X
where
    T: SimdElement,
    X: LastRankInner<LastRank = ConcreteTensor<T, R>>,
{
}

/// Emits `LastRankInner` for `ConcreteTensor<T, RANK>` at every listed rank,
/// mapping it to `ConcreteTensor<T, RANK - 1>`.
macro_rules! impl_last_rank {
    ($($rank:literal),*) => {
        $(
            impl<T: SimdElement> LastRankInner for ConcreteTensor<T, $rank> {
                type LastRank = ConcreteTensor<T, { $rank - 1 }>;
            }
        )*
    };
}

// Ranks 1 through 10; a rank-0 tensor has no smaller rank.
impl_last_rank!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

/// Maps a tensor type to the tensor type with one more dimension.
///
/// Used for unsqueeze, whose result has rank one higher than the input.
pub trait NextRankInner {
    /// The tensor type whose rank is one greater than `Self`'s.
    type NextRank;
}

/// Convenience bound: `X: NextRank<R, T>` holds exactly when `X`'s
/// next-higher-rank type is `ConcreteTensor<T, R>`.
pub trait NextRank<const R: usize, T: SimdElement>:
    NextRankInner<NextRank = ConcreteTensor<T, R>>
{
}

impl<X, T, const R: usize> NextRank<R, T> for X
where
    T: SimdElement,
    X: NextRankInner<NextRank = ConcreteTensor<T, R>>,
{
}

/// Emits `NextRankInner` for `ConcreteTensor<T, RANK>` at every listed rank,
/// mapping it to `ConcreteTensor<T, RANK + 1>`.
macro_rules! impl_next_rank {
    ($($rank:literal),*) => {
        $(
            impl<T: SimdElement> NextRankInner for ConcreteTensor<T, $rank> {
                type NextRank = ConcreteTensor<T, { $rank + 1 }>;
            }
        )*
    };
}

// Ranks 0 through 9, so the largest produced rank is 10.
impl_next_rank!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

/// Maps a tensor type to the tensor type with `DIFF` fewer dimensions.
///
/// Used for squeeze and reduce operations that drop several axes at once.
pub trait SmallerRankInner<const DIFF: usize> {
    /// The tensor type whose rank is `DIFF` less than `Self`'s.
    type SmallerRank;
}

/// Convenience bound: `X: SmallerRank<R, DIFF, T>` holds exactly when removing
/// `DIFF` dimensions from `X` yields `ConcreteTensor<T, R>`.
pub trait SmallerRank<const R: usize, const DIFF: usize, T: SimdElement>:
    SmallerRankInner<DIFF, SmallerRank = ConcreteTensor<T, R>>
{
}

impl<X, T, const R: usize, const DIFF: usize> SmallerRank<R, DIFF, T> for X
where
    T: SimdElement,
    X: SmallerRankInner<DIFF, SmallerRank = ConcreteTensor<T, R>>,
{
}

/// Emits `SmallerRankInner<DIFF>` for `ConcreteTensor<T, RANK>` for each listed
/// `DIFF`. The output rank `RANK - DIFF` is computed at compile time, so the
/// mapping cannot drift out of sync with its inputs.
macro_rules! impl_smaller_rank {
    ($rank:literal; $($diff:literal),*) => {
        $(
            impl<T: SimdElement> SmallerRankInner<$diff> for ConcreteTensor<T, $rank> {
                type SmallerRank = ConcreteTensor<T, { $rank - $diff }>;
            }
        )*
    };
}

// For each rank 1..=10: every reduction amount from 1 up to the full rank.
impl_smaller_rank!(1; 1);
impl_smaller_rank!(2; 1, 2);
impl_smaller_rank!(3; 1, 2, 3);
impl_smaller_rank!(4; 1, 2, 3, 4);
impl_smaller_rank!(5; 1, 2, 3, 4, 5);
impl_smaller_rank!(6; 1, 2, 3, 4, 5, 6);
impl_smaller_rank!(7; 1, 2, 3, 4, 5, 6, 7);
impl_smaller_rank!(8; 1, 2, 3, 4, 5, 6, 7, 8);
impl_smaller_rank!(9; 1, 2, 3, 4, 5, 6, 7, 8, 9);
impl_smaller_rank!(10; 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

/// Maps a tensor type to the tensor type with `DIFF` more dimensions.
///
/// Used for unsqueeze and expand operations that add several axes at once.
pub trait LargerRankInner<const DIFF: usize> {
    /// The tensor type whose rank is `DIFF` greater than `Self`'s.
    type LargerRank;
}

/// Convenience bound: `X: LargerRank<R, DIFF, T>` holds exactly when adding
/// `DIFF` dimensions to `X` yields `ConcreteTensor<T, R>`.
pub trait LargerRank<const R: usize, const DIFF: usize, T: SimdElement>:
    LargerRankInner<DIFF, LargerRank = ConcreteTensor<T, R>>
{
}

impl<X, T, const R: usize, const DIFF: usize> LargerRank<R, DIFF, T> for X
where
    T: SimdElement,
    X: LargerRankInner<DIFF, LargerRank = ConcreteTensor<T, R>>,
{
}

/// Emits `LargerRankInner<DIFF>` for `ConcreteTensor<T, RANK>` for each listed
/// `DIFF`. The output rank `RANK + DIFF` is computed at compile time.
macro_rules! impl_larger_rank {
    ($rank:literal; $($diff:literal),*) => {
        $(
            impl<T: SimdElement> LargerRankInner<$diff> for ConcreteTensor<T, $rank> {
                type LargerRank = ConcreteTensor<T, { $rank + $diff }>;
            }
        )*
    };
}

// For each rank 0..=9: every growth amount that keeps the result at rank <= 10.
impl_larger_rank!(0; 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
impl_larger_rank!(1; 1, 2, 3, 4, 5, 6, 7, 8, 9);
impl_larger_rank!(2; 1, 2, 3, 4, 5, 6, 7, 8);
impl_larger_rank!(3; 1, 2, 3, 4, 5, 6, 7);
impl_larger_rank!(4; 1, 2, 3, 4, 5, 6);
impl_larger_rank!(5; 1, 2, 3, 4, 5);
impl_larger_rank!(6; 1, 2, 3, 4);
impl_larger_rank!(7; 1, 2, 3);
impl_larger_rank!(8; 1, 2);
impl_larger_rank!(9; 1);

// Trait for mapping two tensors to their max rank (for broadcasting operations)
pub trait MaxRankInner {
    type MaxRank;
}

pub trait MaxRank<const R: usize, T: SimdElement>:
    MaxRankInner<MaxRank = ConcreteTensor<T, R>>
{
}

impl<const R: usize, T: SimdElement, X> MaxRank<R, T> for X where
    X: MaxRankInner<MaxRank = ConcreteTensor<T, R>>
{
}

// Same rank produces same rank
impl<const N: usize, T: SimdElement> MaxRankInner for (ConcreteTensor<T, N>, ConcreteTensor<T, N>) {
    type MaxRank = ConcreteTensor<T, N>;
}

// Macro to generate MaxRankInner implementations for different rank pairs
macro_rules! impl_max_rank {
    ($R1:literal, $R2:literal) => {
        impl<T: SimdElement> MaxRankInner for (ConcreteTensor<T, $R1>, ConcreteTensor<T, $R2>) {
            type MaxRank = ConcreteTensor<T, $R2>;
        }
        impl<T: SimdElement> MaxRankInner for (ConcreteTensor<T, $R2>, ConcreteTensor<T, $R1>) {
            type MaxRank = ConcreteTensor<T, $R2>;
        }
    };
}

// Generate MaxRank implementations for all rank combinations 0-10
impl_max_rank!(0, 1);
impl_max_rank!(0, 2);
impl_max_rank!(0, 3);
impl_max_rank!(0, 4);
impl_max_rank!(0, 5);
impl_max_rank!(0, 6);
impl_max_rank!(0, 7);
impl_max_rank!(0, 8);
impl_max_rank!(0, 9);
impl_max_rank!(0, 10);
impl_max_rank!(1, 2);
impl_max_rank!(1, 3);
impl_max_rank!(1, 4);
impl_max_rank!(1, 5);
impl_max_rank!(1, 6);
impl_max_rank!(1, 7);
impl_max_rank!(1, 8);
impl_max_rank!(1, 9);
impl_max_rank!(1, 10);
impl_max_rank!(2, 3);
impl_max_rank!(2, 4);
impl_max_rank!(2, 5);
impl_max_rank!(2, 6);
impl_max_rank!(2, 7);
impl_max_rank!(2, 8);
impl_max_rank!(2, 9);
impl_max_rank!(2, 10);
impl_max_rank!(3, 4);
impl_max_rank!(3, 5);
impl_max_rank!(3, 6);
impl_max_rank!(3, 7);
impl_max_rank!(3, 8);
impl_max_rank!(3, 9);
impl_max_rank!(3, 10);
impl_max_rank!(4, 5);
impl_max_rank!(4, 6);
impl_max_rank!(4, 7);
impl_max_rank!(4, 8);
impl_max_rank!(4, 9);
impl_max_rank!(4, 10);
impl_max_rank!(5, 6);
impl_max_rank!(5, 7);
impl_max_rank!(5, 8);
impl_max_rank!(5, 9);
impl_max_rank!(5, 10);
impl_max_rank!(6, 7);
impl_max_rank!(6, 8);
impl_max_rank!(6, 9);
impl_max_rank!(6, 10);
impl_max_rank!(7, 8);
impl_max_rank!(7, 9);
impl_max_rank!(7, 10);
impl_max_rank!(8, 9);
impl_max_rank!(8, 10);
impl_max_rank!(9, 10);

/// Trait for types that support scalar and SIMD evaluation without a rank parameter.
/// This is a supertrait of `TensorBacking` that allows rank-independent access.
///
/// Implementors must also be `Sync`.
pub trait LazyBacking: Sync {
    /// The element type produced by evaluation.
    type Elem: SimdElement;

    /// Evaluate at a single scalar index.
    ///
    /// This is used for:
    /// - Tail elements that don't fill a complete SIMD vector
    /// - Non-contiguous tensor access patterns
    ///
    /// NOTE(review): `idx` is presumably a flat element index — confirm against
    /// the implementors (e.g. `ConcreteTensor`).
    fn eval_scalar(&self, idx: usize) -> Self::Elem;

    /// Evaluate a SIMD chunk starting at the given base index.
    ///
    /// The returned SIMD vector contains multiple consecutive elements
    /// starting at `base_idx`. The caller must ensure that there are
    /// enough elements remaining to fill a complete SIMD vector.
    ///
    /// NOTE(review): this is a safe fn, so the precondition above is a logical
    /// contract only; what happens when it is violated depends on the implementor.
    fn eval_simd<S: Simd>(&self, simd: S, base_idx: usize) -> <Self::Elem as SimdElement>::Simd<S>;
}

/// A [`LazyBacking`] with a statically known rank `R`.
pub trait TensorBacking<const R: usize>: LazyBacking {
    /// Returns the `Layout` of this rank-`R` tensor.
    fn layout(&self) -> Layout;
    /// Produces a `ConcreteTensor` of rank `R` holding this backing's element type.
    ///
    /// NOTE(review): the name suggests lazy backings are evaluated here; whether an
    /// already-concrete backing is copied is not visible from this declaration.
    fn to_concrete(&self) -> ConcreteTensor<Self::Elem, R>;
}

// Shared references forward to the referenced backing, so `&backing` can be
// passed anywhere a `LazyBacking` / `TensorBacking` is accepted.
impl<T> LazyBacking for &T
where
    T: LazyBacking + Sync,
{
    type Elem = <T as LazyBacking>::Elem;

    #[inline(always)]
    fn eval_scalar(&self, idx: usize) -> Self::Elem {
        // `*self` is the inner `&T`; name `T`'s impl explicitly so the call can
        // never resolve back to this forwarding impl.
        <T as LazyBacking>::eval_scalar(*self, idx)
    }

    #[inline(always)]
    fn eval_simd<S: Simd>(
        &self,
        simd: S,
        base_idx: usize,
    ) -> <Self::Elem as SimdElement>::Simd<S> {
        <T as LazyBacking>::eval_simd(*self, simd, base_idx)
    }
}

impl<T, const R: usize> TensorBacking<R> for &T
where
    T: TensorBacking<R> + Sync,
{
    fn layout(&self) -> Layout {
        <T as TensorBacking<R>>::layout(*self)
    }

    fn to_concrete(&self) -> ConcreteTensor<Self::Elem, R> {
        <T as TensorBacking<R>>::to_concrete(*self)
    }
}

/// A [`TensorBacking`] whose elements live in a directly accessible aligned
/// buffer (`ABox<[Elem]>`).
pub trait ResolvedTensor<const R: usize>: TensorBacking<R> {
    /// Borrows the element buffer.
    ///
    /// NOTE(review): element order presumably follows `layout()` — confirm.
    fn data(&self) -> &ABox<[Self::Elem]>;
    /// Mutably borrows the element buffer.
    fn data_mut(&mut self) -> &mut ABox<[Self::Elem]>;
}

/// Trait for SIMD element types with associated SIMD vector type
///
/// This module implements it for the primitive integer and float types through
/// `impl_simd_element!`, and for `half::f16` with a one-lane scalar fallback.
pub trait SimdElement: Sized + Copy + Default + Pod + Sync + Send {
    /// The SIMD vector type for this element (GAT over the `pulp::Simd` backend `S`)
    type Simd<S: Simd>: Copy;

    /// Convert slice to SIMD vectors + remainder
    ///
    /// Returns the part of `slice` viewed as whole SIMD vectors, plus the scalar
    /// elements that do not fill a complete vector.
    fn as_simd<S: Simd>(slice: &[Self]) -> (&[Self::Simd<S>], &[Self]);
    /// Mutable counterpart of `as_simd`.
    fn as_mut_simd<S: Simd>(slice: &mut [Self]) -> (&mut [Self::Simd<S>], &mut [Self]);

    /// Broadcast a scalar value to all lanes of a SIMD vector
    fn splat<S: Simd>(simd: S, value: Self) -> Self::Simd<S>;

    /// Gather elements from the slice at the specified indices using SIMD.
    ///
    /// # Safety
    /// All indices must be valid indices into the slice.
    ///
    /// NOTE(review): implementations presumably read up to `lane_count` entries of
    /// `indices`, so callers should also ensure `indices.len() >= lane_count` —
    /// confirm against `gather::gather_impl`.
    ///
    /// # Arguments
    /// * `simd` - The SIMD context
    /// * `slice` - The source data slice
    /// * `indices` - Array of indices to gather from
    /// * `lane_count` - Number of SIMD lanes to fill
    ///
    /// Uses hardware SIMD gather instructions (AVX2, AVX-512) when available,
    /// falling back to scalar loads on other architectures. (The `half::f16`
    /// implementation only reads `indices[0]`.)
    unsafe fn gather_unchecked<S: Simd>(
        simd: S,
        slice: &[Self],
        indices: &[usize],
        lane_count: usize,
    ) -> Self::Simd<S>;
}

/// Implements `SimdElement` for a primitive type by delegating to pulp.
///
/// Arguments, in order:
/// - `$elem`: the element type being implemented;
/// - `$vector`: the matching `Simd` associated vector type (e.g. `f32s`);
/// - `$split` / `$split_mut`: the `Simd` associated functions that view a shared /
///   mutable slice as vectors plus a scalar remainder;
/// - `$broadcast`: the `Simd` method that copies one scalar into every lane.
macro_rules! impl_simd_element {
    ($elem:ty, $vector:ident, $split:ident, $split_mut:ident, $broadcast:ident) => {
        impl SimdElement for $elem {
            type Simd<S: Simd> = S::$vector;

            #[inline(always)]
            fn as_simd<S: Simd>(slice: &[Self]) -> (&[Self::Simd<S>], &[Self]) {
                S::$split(slice)
            }

            #[inline(always)]
            fn as_mut_simd<S: Simd>(slice: &mut [Self]) -> (&mut [Self::Simd<S>], &mut [Self]) {
                S::$split_mut(slice)
            }

            #[inline(always)]
            fn splat<S: Simd>(simd: S, value: Self) -> Self::Simd<S> {
                simd.$broadcast(value)
            }

            #[inline(always)]
            unsafe fn gather_unchecked<S: Simd>(
                simd: S,
                slice: &[Self],
                indices: &[usize],
                lane_count: usize,
            ) -> Self::Simd<S> {
                // SAFETY: this method's `# Safety` contract makes the caller guarantee
                // every index is valid for `slice`; that guarantee is passed through
                // unchanged to `gather_impl`.
                unsafe { gather::gather_impl::<Self, S>(simd, slice, indices, lane_count) }
            }
        }
    };
}

impl_simd_element!(f32, f32s, as_simd_f32s, as_mut_simd_f32s, splat_f32s);
impl_simd_element!(f64, f64s, as_simd_f64s, as_mut_simd_f64s, splat_f64s);
impl_simd_element!(i8, i8s, as_simd_i8s, as_mut_simd_i8s, splat_i8s);
impl_simd_element!(i16, i16s, as_simd_i16s, as_mut_simd_i16s, splat_i16s);
impl_simd_element!(i32, i32s, as_simd_i32s, as_mut_simd_i32s, splat_i32s);
impl_simd_element!(i64, i64s, as_simd_i64s, as_mut_simd_i64s, splat_i64s);
impl_simd_element!(u8, u8s, as_simd_u8s, as_mut_simd_u8s, splat_u8s);
impl_simd_element!(u16, u16s, as_simd_u16s, as_mut_simd_u16s, splat_u16s);
impl_simd_element!(u32, u32s, as_simd_u32s, as_mut_simd_u32s, splat_u32s);
impl_simd_element!(u64, u64s, as_simd_u64s, as_mut_simd_u64s, splat_u64s);

/// Wrapper type for f16 "SIMD" operations.
///
/// Since pulp doesn't have native f16 SIMD support, we use a scalar fallback.
/// This type wraps a single f16 value and presents itself as a "SIMD vector"
/// with one lane. Operations fall back to scalar code.
#[derive(Copy, Clone, Debug)]
pub struct F16Scalar(pub half::f16);

impl SimdElement for half::f16 {
    /// The "SIMD" type for f16 is a one-lane scalar wrapper since pulp lacks f16 SIMD.
    type Simd<S: Simd> = F16Scalar;

    /// Always returns an empty vector slice, so every element lands in the scalar
    /// remainder and callers take their scalar path.
    #[inline(always)]
    fn as_simd<S: Simd>(slice: &[Self]) -> (&[Self::Simd<S>], &[Self]) {
        (&[], slice)
    }

    /// Mutable counterpart of `as_simd`: empty vector slice, everything in the remainder.
    #[inline(always)]
    fn as_mut_simd<S: Simd>(slice: &mut [Self]) -> (&mut [Self::Simd<S>], &mut [Self]) {
        (&mut [], slice)
    }

    /// Wraps `value` as the single lane.
    #[inline(always)]
    fn splat<S: Simd>(_simd: S, value: Self) -> Self::Simd<S> {
        F16Scalar(value)
    }

    /// Loads the single lane from `slice[indices[0]]`.
    ///
    /// Only the first index is read because an `F16Scalar` has exactly one lane,
    /// so `lane_count` is ignored. The load is bounds-checked: an out-of-range
    /// index (or an empty `indices`) panics rather than reading out of bounds.
    #[inline(always)]
    unsafe fn gather_unchecked<S: Simd>(
        _simd: S,
        slice: &[Self],
        indices: &[usize],
        _lane_count: usize,
    ) -> Self::Simd<S> {
        F16Scalar(slice[indices[0]])
    }
}