//! Expression template system for lazy operation fusion
//!
//! This module provides types that enable automatic fusion of elementwise
//! operations. When multiple operations are chained (e.g.,
//! `x.mul_ref(&y).add_ref(&z).sqrt_ref()`), they are evaluated in a single
//! SIMD loop instead of multiple passes.

use pulp::{Arch, Simd, WithSimd};

use crate::{ConcreteTensor, SimdElement, TensorBacking};

/// Number of `E` elements packed into one SIMD vector `E::Simd<S>`.
///
/// Derived purely from byte sizes: the width of the vector type divided by
/// the width of a single element.
#[inline(always)]
fn simd_lane_count<E: SimdElement, S: Simd>() -> usize {
    let vector_bytes = std::mem::size_of::<E::Simd<S>>();
    let element_bytes = std::mem::size_of::<E>();
    vector_bytes / element_bytes
}

/// SIMD-dispatchable job that writes an expression's values into a buffer.
///
/// Running it through `Arch::dispatch` evaluates every operation of the
/// expression tree in one pass over `out`, so chained elementwise operations
/// are fused instead of materializing intermediates.
struct TensorEvaluator<'a, T, const R: usize>
where
    T: TensorBacking<R>,
{
    /// Expression (or plain tensor) being evaluated.
    tensor: &'a T,
    /// Destination buffer; element `k` receives the value at logical index
    /// `base_offset + k` of `tensor`.
    out: &'a mut [T::Elem],
    /// Logical offset into the tensor for this chunk.
    base_offset: usize,
}

impl<T, const R: usize> WithSimd for TensorEvaluator<'_, T, R>
where
    T: TensorBacking<R>,
{
    type Output = ();

    /// Fill `out` with the tensor's values, vector-wide where possible.
    ///
    /// `out` is split into a run of full SIMD vectors followed by a scalar
    /// remainder. A single running logical index walks both parts, starting
    /// at `base_offset`.
    #[inline(always)]
    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
        let Self {
            tensor,
            out,
            base_offset,
        } = self;

        let lanes = simd_lane_count::<T::Elem, S>();
        let (vectors, leftovers) = T::Elem::as_mut_simd::<S>(out);

        // Each full vector covers `lanes` consecutive logical indices.
        let mut logical = base_offset;
        for vector in vectors.iter_mut() {
            *vector = tensor.eval_simd(simd, logical);
            logical += lanes;
        }

        // Elements that did not fill a whole vector are computed one by one,
        // continuing from where the vector run stopped.
        for scalar in leftovers.iter_mut() {
            *scalar = tensor.eval_scalar(logical);
            logical += 1;
        }
    }
}

/// Element count at or above which `materialize_expr` may split work across threads.
///
/// This is intentionally large. `std::thread::scope` spawns brand-new threads
/// on every call, so for repeated operations (e.g. per-layer work in a
/// transformer) the spawn/join cost outweighs any speedup. A persistent thread
/// pool would justify a lower value; until then, typical LLM tensor sizes stay
/// single-threaded.
///
/// 2^24 = 16_777_216 elements, roughly 64MB of `f32`.
const PARALLEL_THRESHOLD: usize = 1 << 24;

/// Materialize a tensor backing into a new ConcreteTensor.
///
/// This evaluates the entire expression tree in a single fused SIMD loop,
/// writing the results into a newly allocated (zero-initialized) tensor of
/// the given `shape`.
///
/// When the output has at least `PARALLEL_THRESHOLD` elements and
/// `crate::parallel::num_threads()` reports more than one thread, the output
/// is split into at most that many contiguous chunks. Each chunk is evaluated
/// on its own thread via `std::thread::scope`. Otherwise evaluation runs on
/// the calling thread.
///
/// NOTE(review): assumes `shape` matches the logical shape of `tensor`;
/// nothing here checks it.
#[inline]
#[must_use = "this allocates a new tensor; discarding it wastes computation"]
pub fn materialize_expr<T: TensorBacking<R> + Sync, const R: usize>(
    tensor: &T,
    shape: [usize; R],
) -> ConcreteTensor<T::Elem, R> {
    let mut output = ConcreteTensor::<T::Elem, R>::zeros(shape);
    let total_elements = output.backing().len();
    let n_threads = crate::parallel::num_threads();

    if total_elements >= PARALLEL_THRESHOLD && n_threads > 1 {
        // Round the per-thread share up to a multiple of MAX_SIMD_LANES, so
        // every chunk's starting logical offset is a multiple of
        // MAX_SIMD_LANES. Because chunk_size >= ceil(total / n_threads),
        // `chunks_mut` yields at most `n_threads` chunks. Only the final
        // chunk can be shorter.
        let chunk_size = total_elements
            .div_ceil(n_threads)
            .next_multiple_of(crate::MAX_SIMD_LANES);
        let out_slice: &mut [T::Elem] = output.backing_mut();

        std::thread::scope(|scope| {
            for (chunk_idx, chunk) in out_slice.chunks_mut(chunk_size).enumerate() {
                let evaluator = TensorEvaluator {
                    tensor,
                    out: chunk,
                    // Every preceding chunk is exactly `chunk_size` long.
                    base_offset: chunk_idx * chunk_size,
                };

                scope.spawn(move || {
                    Arch::new().dispatch(evaluator);
                });
            }
        });
    } else {
        // Small tensor (or a single worker): evaluate on the calling thread.
        let out_slice: &mut [T::Elem] = output.backing_mut();
        Arch::new().dispatch(TensorEvaluator {
            tensor,
            out: out_slice,
            base_offset: 0,
        });
    }

    output
}

/// Map a flat, row-major index onto per-dimension coordinates for `shape`.
///
/// The last dimension varies fastest. Strided tensor access uses this to turn
/// a flat iteration index into multi-dimensional tensor coordinates.
///
/// `linear` is not range-checked. A value at or beyond the total element
/// count wraps around in the outermost coordinate. Any zero-sized dimension
/// triggers a remainder-by-zero panic.
#[inline]
pub(crate) fn linear_to_indices<const R: usize>(
    mut linear: usize,
    shape: &[usize; R],
) -> [usize; R] {
    let mut coords = [0usize; R];

    // Peel off the innermost dimension first (row-major layout).
    for (coord, &extent) in coords.iter_mut().zip(shape.iter()).rev() {
        *coord = linear % extent;
        linear /= extent;
    }

    coords
}