//! Support for GGUF quantized tensors
//!
//! This module provides `QuantizedTensor` for storing and operating on
//! quantized data from GGUF files. It supports:
//! - Multiple quantization types (Q4_0, Q5_0, Q8_0, Q4K, Q6K)
//! - Eager full dequantization to f32
//! - Lazy dequantization via the `Dequantize` expression type
//! - Efficient block-by-block matrix multiplication

use aligned_vec::{ABox, AVec};
use bytemuck::Pod;
use fusor_gguf::GgufBlock;
use pulp::Simd;

use fusor_types::Layout;

use crate::expr::materialize_expr;
use crate::reduce::{SimdReduceOp, SumOp};
use crate::{ConcreteTensor, MAX_SIMD_LANES, ResolvedTensor, SimdElement, TensorBacking};

/// A tensor storing quantized blocks.
///
/// `QuantizedTensor<B>` stores data in quantized block format where `B` is
/// the block type (e.g., `BlockQ4_0`). The rank is dynamic at runtime: the
/// shape is kept as a boxed slice rather than a const-generic array.
///
/// The innermost dimension must be a multiple of the block size
/// (`B::BLOCK_SIZE`), so every row consists of whole blocks. For example,
/// a [3, 256] tensor with Q4_0 quantization (block size 32) stores 3 rows of
/// 8 blocks each.
#[derive(Clone)]
pub struct QuantizedTensor<B: GgufBlock> {
    /// The logical shape in elements (not blocks). `from_blocks` rejects an
    /// empty shape and a last dimension that is not a multiple of
    /// `B::BLOCK_SIZE`.
    element_shape: Box<[usize]>,
    /// The quantized blocks stored in row-major order. `from_blocks` checks
    /// that the count equals the element count divided by `B::BLOCK_SIZE`.
    blocks: ABox<[B]>,
}

impl<B: GgufBlock> QuantizedTensor<B> {
    /// Create a quantized tensor from pre-existing blocks.
    ///
    /// # Arguments
    /// * `element_shape` - The logical shape in elements (not blocks).
    ///   The innermost dimension must be a multiple of `B::BLOCK_SIZE`.
    /// * `blocks` - The quantized blocks in row-major order.
    ///
    /// # Panics
    /// Panics if:
    /// - The innermost dimension is not a multiple of the block size
    /// - The number of blocks doesn't match the shape
    pub fn from_blocks(element_shape: impl Into<Box<[usize]>>, blocks: ABox<[B]>) -> Self {
        let element_shape = element_shape.into();
        let rank = element_shape.len();
        assert!(rank > 0, "Tensor must have at least rank 1");
        let inner_dim = element_shape[rank - 1];
        assert!(
            inner_dim % B::BLOCK_SIZE == 0,
            "Innermost dimension ({}) must be a multiple of block size ({})",
            inner_dim,
            B::BLOCK_SIZE
        );

        let expected_blocks = Self::compute_block_count(&element_shape);
        assert_eq!(
            blocks.len(),
            expected_blocks,
            "Expected {} blocks for shape {:?}, got {}",
            expected_blocks,
            element_shape,
            blocks.len()
        );

        Self {
            element_shape,
            blocks,
        }
    }

    /// Create a quantized tensor from raw bytes.
    ///
    /// This interprets the bytes as a slice of blocks using bytemuck.
    ///
    /// # Arguments
    /// * `element_shape` - The logical shape in elements (not blocks).
    /// * `bytes` - Raw bytes that will be cast to blocks.
    ///
    /// # Panics
    /// Panics if:
    /// - The bytes length is not a multiple of the block size
    /// - The innermost dimension is not a multiple of the block size
    /// - The number of blocks doesn't match the shape
    pub fn from_raw_bytes(element_shape: impl Into<Box<[usize]>>, bytes: &[u8]) -> Self {
        let blocks_slice: &[B] = pulp::bytemuck::cast_slice(bytes);
        let mut vec: AVec<B> = AVec::with_capacity(64, blocks_slice.len());
        vec.extend_from_slice(blocks_slice);
        Self::from_blocks(element_shape, vec.into_boxed_slice())
    }

    /// Compute the number of blocks needed for a given element shape.
    fn compute_block_count(element_shape: &[usize]) -> usize {
        let total_elements: usize = element_shape.iter().product();
        total_elements / B::BLOCK_SIZE
    }

    /// Returns the logical element shape (not block shape).
    pub fn element_shape(&self) -> &[usize] {
        &self.element_shape
    }

    /// Returns the total number of logical elements.
    pub fn element_count(&self) -> usize {
        self.element_shape.iter().product()
    }

    /// Returns the number of blocks.
    pub fn block_count(&self) -> usize {
        self.blocks.len()
    }

    /// Returns a reference to the underlying blocks.
    pub fn blocks(&self) -> &[B] {
        &self.blocks
    }

    /// Eagerly dequantize the entire tensor to f32.
    ///
    /// This allocates a new `ConcreteTensor<f32, R>` and dequantizes all blocks.
    /// For large tensors, consider using `dequantize_lazy()` instead.
    ///
    /// # Panics
    /// Panics if the tensor's rank doesn't match R.
    pub fn dequantize<const R: usize>(&self) -> ConcreteTensor<f32, R> {
        let shape: [usize; R] = self
            .element_shape
            .as_ref()
            .try_into()
            .expect("Shape length mismatch in dequantize");
        let layout = fusor_types::Layout::contiguous(&shape);
        let n = layout.num_elements();
        let mut vec: AVec<f32> = AVec::with_capacity(64, n);

        for block in self.blocks.iter() {
            let dequantized = block.dequantize();
            vec.extend_from_slice(dequantized.as_ref());
        }

        ConcreteTensor::from_parts(layout, vec.into_boxed_slice())
    }

    /// Create a lazy dequantization expression.
    ///
    /// This returns a `Dequantize` expression that implements `Expr`,
    /// allowing it to be composed with other operations before materialization.
    ///
    /// # Panics
    /// Panics if the tensor's rank doesn't match R.
    pub fn dequantize_lazy<const R: usize>(&self) -> Dequantize<'_, B, R> {
        assert_eq!(
            self.element_shape.len(),
            R,
            "Tensor rank {} doesn't match expected rank {}",
            self.element_shape.len(),
            R
        );
        Dequantize { source: self }
    }
}

/// Lazy dequantization expression.
///
/// A borrowed view over a [`QuantizedTensor`] that implements
/// `crate::LazyBacking` and `TensorBacking<R>` in this module.
/// Instead of dequantizing the entire tensor upfront, values are
/// dequantized on-demand during expression evaluation.
///
/// `R` is the rank the view is interpreted with; `dequantize_lazy` asserts
/// that it matches the source tensor's rank before constructing the view.
pub struct Dequantize<'a, B: GgufBlock, const R: usize> {
    /// The quantized tensor whose blocks are decoded on demand.
    source: &'a QuantizedTensor<B>,
}

impl<B: GgufBlock, const R: usize> crate::LazyBacking for Dequantize<'_, B, R>
where
    B::Dequantized: AsRef<[f32]>,
{
    type Elem = f32;

    /// Dequantize the block containing the flat element index `idx` and return
    /// that single element.
    ///
    /// Panics if `idx` lies past the last block.
    #[inline(always)]
    fn eval_scalar(&self, idx: usize) -> f32 {
        let block_idx = idx / B::BLOCK_SIZE;
        let elem_idx = idx % B::BLOCK_SIZE;
        self.source.blocks[block_idx].dequantize().as_ref()[elem_idx]
    }

    /// Gather one SIMD vector of consecutive elements starting at `base_idx`.
    ///
    /// Block boundaries don't typically align with SIMD lanes, so the lanes are
    /// staged through a scalar buffer. Each block the vector overlaps is
    /// dequantized once and the needed run is copied out, rather than
    /// dequantizing the whole block again for every lane.
    ///
    /// Panics if any lane lies past the last block.
    #[inline(always)]
    fn eval_simd<S: Simd>(&self, _simd: S, base_idx: usize) -> <f32 as SimdElement>::Simd<S> {
        let lane_count =
            std::mem::size_of::<<f32 as SimdElement>::Simd<S>>() / std::mem::size_of::<f32>();
        let mut temp = [0.0f32; MAX_SIMD_LANES];

        let mut filled = 0;
        while filled < lane_count {
            let idx = base_idx + filled;
            let block_idx = idx / B::BLOCK_SIZE;
            let offset = idx % B::BLOCK_SIZE;

            let dequantized = self.source.blocks[block_idx].dequantize();
            let values: &[f32] = dequantized.as_ref();

            // Take everything this block can supply, capped by the lanes left.
            let run = (B::BLOCK_SIZE - offset).min(lane_count - filled);
            temp[filled..filled + run].copy_from_slice(&values[offset..offset + run]);
            filled += run;
        }

        let (simd_vec, _) = f32::as_simd::<S>(&temp[..lane_count]);
        simd_vec[0]
    }
}

impl<B: GgufBlock, const R: usize> TensorBacking<R> for Dequantize<'_, B, R>
where
    B::Dequantized: AsRef<[f32]>,
{
    /// Contiguous layout over the source tensor's element shape.
    ///
    /// Panics if the source rank differs from `R`.
    fn layout(&self) -> Layout {
        let dims: &[usize] = &self.source.element_shape;
        let shape =
            <[usize; R]>::try_from(dims).expect("Shape length mismatch in Dequantize::layout");
        Layout::contiguous(&shape)
    }

    /// Materialize the lazy expression into a concrete f32 tensor.
    ///
    /// Panics if the source rank differs from `R`.
    fn to_concrete(&self) -> ConcreteTensor<f32, R> {
        let dims: &[usize] = &self.source.element_shape;
        let shape = <[usize; R]>::try_from(dims)
            .expect("Shape length mismatch in Dequantize::to_concrete");
        materialize_expr(self, shape)
    }
}

/// Matrix multiplication with a quantized RHS.
///
/// Computes `self @ transpose(rhs)` where `self` is an f32 tensor and `rhs` is a
/// quantized weight stored as `[N, K]` (`[out_features, in_features]`).
/// This processes blocks one at a time to avoid the memory cost of full dequantization.
/// Supports batched inputs: `[batch_dims..., M, K] x [N, K] -> [batch_dims..., M, N]`
impl<const R: usize> ConcreteTensor<f32, R> {
    /// Matrix multiplication: self (`[batch_dims..., M, K]`) times the transpose
    /// of rhs (`[N, K]`), producing `[batch_dims..., M, N]`.
    ///
    /// This is optimized for the case where the RHS (weights) are quantized.
    /// Instead of dequantizing the entire RHS matrix, it processes block-by-block
    /// with SIMD acceleration.
    ///
    /// The RHS must be 2D and stored as `[out_features, in_features]`, i.e. each
    /// weight row holds the `K` coefficients of one output column. The LHS can
    /// have arbitrary leading batch dimensions; the same weights are applied to
    /// every batch.
    ///
    /// A non-contiguous LHS is supported: each batch matrix is first gathered
    /// into a contiguous scratch buffer.
    ///
    /// # Panics
    /// Panics if:
    /// - rhs is not 2D
    /// - the last LHS dimension differs from the RHS inner dimension (`K`)
    pub fn q_mat_mul<B: GgufBlock + Sync>(&self, rhs: &QuantizedTensor<B>) -> ConcreteTensor<f32, R>
    where
        B::Dequantized: AsRef<[f32]>,
        B::ActivationBlock: Pod + Send + Sync,
    {
        const { assert!(R >= 2, "q_mat_mul requires at least 2 dimensions") };

        let rhs_shape = rhs.element_shape();
        assert_eq!(
            rhs_shape.len(),
            2,
            "q_mat_mul requires 2D weight tensor, got {}D",
            rhs_shape.len()
        );

        let lhs_shape = self.layout().shape();
        let m = lhs_shape[R - 2];
        let k = lhs_shape[R - 1];
        // Weight is stored as [out_features, in_features] to match GPU convention
        let n = rhs_shape[0]; // out_features
        let k2 = rhs_shape[1]; // in_features

        assert_eq!(
            k, k2,
            "Matrix dimension mismatch: lhs columns ({}) != weight in_features ({})",
            k, k2
        );

        // Output shape: keep batch dims and M, replace the last dimension with N
        let mut out_shape: [usize; R] = [0; R];
        out_shape.copy_from_slice(lhs_shape);
        out_shape[R - 1] = n;

        let mut output = ConcreteTensor::<f32, R>::zeros(out_shape);

        // Product of all dims except the last 2 (an empty product is 1)
        let batch_size: usize = lhs_shape[..R - 2].iter().product();

        let lhs_matrix_size = m * k;
        let out_matrix_size = m * n;

        // Weight is [N, K], so blocks per row of weight = K / BLOCK_SIZE
        let blocks_per_weight_row = k / B::BLOCK_SIZE;

        if self.layout().is_contiguous() {
            // Fast path: each batch matrix is a plain sub-slice of the LHS data
            let lhs_data = self.data();
            let out_data = output.data_mut();

            for b in 0..batch_size {
                let lhs_slice = &lhs_data[b * lhs_matrix_size..(b + 1) * lhs_matrix_size];
                let out_slice = &mut out_data[b * out_matrix_size..(b + 1) * out_matrix_size];

                pulp::Arch::new().dispatch(QMatmulSimd {
                    lhs_data: lhs_slice,
                    rhs_blocks: rhs.blocks(),
                    out_data: out_slice,
                    m,
                    k,
                    n,
                    blocks_per_weight_row,
                    _phantom: std::marker::PhantomData::<B>,
                });
            }
        } else {
            // Slow path: gather each batch matrix into contiguous scratch memory.
            let batch_dims = &lhs_shape[..R - 2];
            let layout = self.layout();
            let lhs_data = self.data();

            // Multi-dimensional LHS index; the leading R-2 entries are the batch
            // coordinates and persist across iterations.
            let mut lhs_index = [0usize; R];
            // Scratch buffer reused for every batch; each gather overwrites all of it.
            let mut lhs_batch = vec![0.0f32; lhs_matrix_size];

            for b in 0..batch_size {
                for i in 0..m {
                    lhs_index[R - 2] = i;
                    for l in 0..k {
                        lhs_index[R - 1] = l;
                        let src = layout.linear_index(&lhs_index);
                        lhs_batch[i * k + l] = lhs_data[src];
                    }
                }

                let out_slice =
                    &mut output.data_mut()[b * out_matrix_size..(b + 1) * out_matrix_size];

                pulp::Arch::new().dispatch(QMatmulSimd {
                    lhs_data: &lhs_batch,
                    rhs_blocks: rhs.blocks(),
                    out_data: out_slice,
                    m,
                    k,
                    n,
                    blocks_per_weight_row,
                    _phantom: std::marker::PhantomData::<B>,
                });

                // Advance the batch coordinates (like a multi-digit counter,
                // least-significant digit last)
                for d in (0..R - 2).rev() {
                    lhs_index[d] += 1;
                    if lhs_index[d] < batch_dims[d] {
                        break;
                    }
                    lhs_index[d] = 0;
                }
            }
        }

        output
    }
}

/// SIMD-accelerated quantized matmul kernel
///
/// Bundles the operands of one `[M, K] x [N, K] -> [M, N]` product so it can be
/// handed to `pulp::Arch::dispatch`, which calls `with_simd` with the selected
/// SIMD implementation.
struct QMatmulSimd<'a, B: GgufBlock> {
    /// Row-major f32 activations; row `i` is `lhs_data[i * k..(i + 1) * k]`
    lhs_data: &'a [f32],
    /// Quantized weight blocks; the block at position `block_idx` of weight row
    /// `n_out` is `rhs_blocks[n_out * blocks_per_weight_row + block_idx]`
    rhs_blocks: &'a [B],
    /// Row-major output; row `i` is `out_data[i * n..(i + 1) * n]`
    out_data: &'a mut [f32],
    /// Number of activation rows
    m: usize,
    /// Shared inner dimension (length of one activation row)
    k: usize,
    /// Number of output columns (weight rows)
    n: usize,
    /// Number of blocks per row of the weight matrix [N, K]
    blocks_per_weight_row: usize,
    /// Marker for the block type `B`
    _phantom: std::marker::PhantomData<B>,
}

impl<B: GgufBlock + Sync> pulp::WithSimd for QMatmulSimd<'_, B>
where
    B::Dequantized: AsRef<[f32]>,
    B::ActivationBlock: Pod + Send + Sync,
{
    type Output = ();

    /// Run the matmul for one `[M, K]` activation matrix.
    ///
    /// Strategy by row count:
    /// - `m == 1`: split the output columns across threads (when `n >= 64` and
    ///   more than one thread is available), otherwise compute the row inline.
    /// - `m >= 4`: split the rows across threads when more than one is available.
    /// - otherwise: process rows sequentially on the calling thread.
    ///
    /// Work executed on spawned threads re-enters the SIMD context through
    /// `simd.vectorize`: the worker closure is a separate function that does not
    /// get inlined into the dispatching function, so without this the inlined
    /// kernels would be compiled without the selected target features.
    #[inline(always)]
    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
        let Self {
            lhs_data,
            rhs_blocks,
            out_data,
            m,
            k,
            n,
            blocks_per_weight_row,
            ..
        } = self;

        // Use f32 dequantize path: dequantize weight blocks to f32 and use SIMD mul_add.
        // This avoids quantizing activations (which introduces compounding error across layers).
        // Special fast path for m=1 (common inference case): parallelize over output columns
        if m == 1 {
            let n_threads = crate::parallel::num_threads();

            // For small n or no extra threads, don't parallelize. `<= 1` also
            // keeps a reported thread count of 0 away from the divisions below.
            if n < 64 || n_threads <= 1 {
                process_row_simd_tiled::<B, S>(
                    simd,
                    lhs_data,
                    rhs_blocks,
                    out_data,
                    n,
                    blocks_per_weight_row,
                );
            } else {
                // Parallelize over output column chunks using scoped threads
                const CHUNK_SIZE: usize = 32;
                let total_chunks = n.div_ceil(CHUNK_SIZE);
                let chunks_per_thread = total_chunks.div_ceil(n_threads);
                let elements_per_thread = chunks_per_thread * CHUNK_SIZE;

                std::thread::scope(|scope| {
                    let mut remaining = out_data;
                    let mut start_n = 0;

                    for thread_id in 0..n_threads {
                        if remaining.is_empty() {
                            break;
                        }

                        // The last thread takes whatever is left
                        let this_size = if thread_id == n_threads - 1 {
                            remaining.len()
                        } else {
                            elements_per_thread.min(remaining.len())
                        };

                        let (thread_chunk, rest) = remaining.split_at_mut(this_size);
                        remaining = rest;
                        let thread_start_n = start_n;
                        start_n += this_size;

                        scope.spawn(move || {
                            simd.vectorize(
                                #[inline(always)]
                                || {
                                    // Process each CHUNK_SIZE piece within this thread
                                    for (i, out_chunk) in
                                        thread_chunk.chunks_mut(CHUNK_SIZE).enumerate()
                                    {
                                        let chunk_start = thread_start_n + i * CHUNK_SIZE;
                                        process_row_simd_range::<B, S>(
                                            simd,
                                            lhs_data,
                                            rhs_blocks,
                                            out_chunk,
                                            chunk_start,
                                            out_chunk.len(),
                                            blocks_per_weight_row,
                                        );
                                    }
                                },
                            )
                        });
                    }
                });
            }
        } else {
            // Only matrices with at least 4 rows are worth spreading over threads;
            // smaller ones (and m == 0) never query the thread count.
            let n_threads = if m >= 4 {
                crate::parallel::num_threads()
            } else {
                1
            };

            if n_threads <= 1 {
                // Sequential processing
                for i in 0..m {
                    let lhs_row = &lhs_data[i * k..(i + 1) * k];
                    let out_row = &mut out_data[i * n..(i + 1) * n];
                    process_row_simd_tiled::<B, S>(
                        simd,
                        lhs_row,
                        rhs_blocks,
                        out_row,
                        n,
                        blocks_per_weight_row,
                    );
                }
            } else {
                // Process rows in parallel using scoped threads
                let rows_per_thread = m.div_ceil(n_threads);

                std::thread::scope(|scope| {
                    let mut remaining_out = out_data;
                    let mut row_offset = 0;

                    for thread_id in 0..n_threads {
                        if remaining_out.is_empty() {
                            break;
                        }

                        // The last thread takes whatever rows are left
                        let this_rows = if thread_id == n_threads - 1 {
                            m - row_offset
                        } else {
                            rows_per_thread.min(m - row_offset)
                        };

                        let this_size = this_rows * n;
                        let (thread_out, rest) = remaining_out.split_at_mut(this_size);
                        remaining_out = rest;
                        let thread_row_offset = row_offset;
                        row_offset += this_rows;

                        scope.spawn(move || {
                            simd.vectorize(
                                #[inline(always)]
                                || {
                                    for i in 0..this_rows {
                                        let global_row = thread_row_offset + i;
                                        let lhs_row =
                                            &lhs_data[global_row * k..(global_row + 1) * k];
                                        let out_row = &mut thread_out[i * n..(i + 1) * n];
                                        process_row_simd_tiled::<B, S>(
                                            simd,
                                            lhs_row,
                                            rhs_blocks,
                                            out_row,
                                            n,
                                            blocks_per_weight_row,
                                        );
                                    }
                                },
                            )
                        });
                    }
                });
            }
        }
    }
}

/// Fill a contiguous run of output columns for a single activation row.
///
/// Writes at most `chunk_n` entries of `out_chunk`; entry `i` receives the dot
/// product of `lhs_row` with weight row `start_n + i`. Used by the m=1 path,
/// where each thread owns a disjoint range of columns.
#[inline(always)]
fn process_row_simd_range<B: GgufBlock, S: Simd>(
    simd: S,
    lhs_row: &[f32],
    rhs_blocks: &[B],
    out_chunk: &mut [f32],
    start_n: usize,
    chunk_n: usize,
    blocks_per_weight_row: usize,
) where
    B::Dequantized: AsRef<[f32]>,
{
    // Pair every output slot with the global column index it represents.
    let slots = out_chunk.iter_mut().take(chunk_n);
    for (slot, column) in slots.zip(start_n..) {
        *slot =
            compute_dot_product::<B, S>(simd, lhs_row, rhs_blocks, column, blocks_per_weight_row);
    }
}

/// Compute one full output row, four columns at a time.
///
/// Each tile keeps four independent SIMD accumulators so that one pass over the
/// activation row feeds four weight rows, giving the CPU independent
/// multiply-add chains to overlap. Columns left over after the last full tile
/// are computed one by one with `compute_dot_product`.
#[inline(always)]
fn process_row_simd_tiled<B: GgufBlock, S: Simd>(
    simd: S,
    lhs_row: &[f32],
    rhs_blocks: &[B],
    out_row: &mut [f32],
    n: usize,
    blocks_per_weight_row: usize,
) where
    B::Dequantized: AsRef<[f32]>,
{
    const TILE: usize = 4;
    // Number of columns covered by complete tiles
    let tiled_columns = (n / TILE) * TILE;

    for first_col in (0..tiled_columns).step_by(TILE) {
        // One vector accumulator per column of the tile, plus scalar sums for
        // any elements that do not fill a whole SIMD vector.
        let mut sum_a = simd.splat_f32s(0.0);
        let mut sum_b = simd.splat_f32s(0.0);
        let mut sum_c = simd.splat_f32s(0.0);
        let mut sum_d = simd.splat_f32s(0.0);
        let mut tail_sums = [0.0f32; TILE];

        for block_idx in 0..blocks_per_weight_row {
            // Slice of the activation row that lines up with this weight block
            let lhs_start = block_idx * B::BLOCK_SIZE;
            let lhs_block = &lhs_row[lhs_start..lhs_start + B::BLOCK_SIZE];
            let (lhs_vecs, lhs_tail) = S::as_simd_f32s(lhs_block);

            // Decode the matching block from each of the four weight rows
            let weight_block =
                |column: usize| rhs_blocks[column * blocks_per_weight_row + block_idx].dequantize();
            let weights_a = weight_block(first_col);
            let weights_b = weight_block(first_col + 1);
            let weights_c = weight_block(first_col + 2);
            let weights_d = weight_block(first_col + 3);

            let (vecs_a, tail_a) = S::as_simd_f32s(weights_a.as_ref());
            let (vecs_b, tail_b) = S::as_simd_f32s(weights_b.as_ref());
            let (vecs_c, tail_c) = S::as_simd_f32s(weights_c.as_ref());
            let (vecs_d, tail_d) = S::as_simd_f32s(weights_d.as_ref());

            // Vector part: one activation load feeds all four accumulators
            for (v, &x) in lhs_vecs.iter().enumerate() {
                sum_a = simd.mul_add_f32s(x, vecs_a[v], sum_a);
                sum_b = simd.mul_add_f32s(x, vecs_b[v], sum_b);
                sum_c = simd.mul_add_f32s(x, vecs_c[v], sum_c);
                sum_d = simd.mul_add_f32s(x, vecs_d[v], sum_d);
            }

            // Scalar part for elements past the last full vector
            for (t, &x) in lhs_tail.iter().enumerate() {
                tail_sums[0] += x * tail_a[t];
                tail_sums[1] += x * tail_b[t];
                tail_sums[2] += x * tail_c[t];
                tail_sums[3] += x * tail_d[t];
            }
        }

        // Horizontal reduction of each accumulator, plus its scalar remainder
        out_row[first_col] =
            <SumOp as SimdReduceOp<f32>>::reduce_simd_vec(simd, sum_a) + tail_sums[0];
        out_row[first_col + 1] =
            <SumOp as SimdReduceOp<f32>>::reduce_simd_vec(simd, sum_b) + tail_sums[1];
        out_row[first_col + 2] =
            <SumOp as SimdReduceOp<f32>>::reduce_simd_vec(simd, sum_c) + tail_sums[2];
        out_row[first_col + 3] =
            <SumOp as SimdReduceOp<f32>>::reduce_simd_vec(simd, sum_d) + tail_sums[3];
    }

    // Columns that did not fit into a full tile
    for column in tiled_columns..n {
        out_row[column] =
            compute_dot_product::<B, S>(simd, lhs_row, rhs_blocks, column, blocks_per_weight_row);
    }
}

/// Dot product of one activation row with one quantized weight row.
///
/// Walks the `blocks_per_weight_row` blocks of weight row `n_out`, dequantizing
/// each to f32 and multiplying it against the matching `B::BLOCK_SIZE` slice of
/// `lhs_row`. Full SIMD vectors accumulate into a vector sum; leftover elements
/// accumulate into a scalar sum. The two are combined at the end.
#[inline(always)]
fn compute_dot_product<B: GgufBlock, S: Simd>(
    simd: S,
    lhs_row: &[f32],
    rhs_blocks: &[B],
    n_out: usize,
    blocks_per_weight_row: usize,
) -> f32
where
    B::Dequantized: AsRef<[f32]>,
{
    let mut vector_sum = simd.splat_f32s(0.0);
    let mut tail_sum = 0.0f32;
    let row_start = n_out * blocks_per_weight_row;

    for block_idx in 0..blocks_per_weight_row {
        let decoded = rhs_blocks[row_start + block_idx].dequantize();
        let weights: &[f32] = decoded.as_ref();

        let lhs_start = block_idx * B::BLOCK_SIZE;
        let lhs_block = &lhs_row[lhs_start..lhs_start + B::BLOCK_SIZE];

        let (lhs_vecs, lhs_tail) = S::as_simd_f32s(lhs_block);
        let (weight_vecs, weight_tail) = S::as_simd_f32s(weights);

        // Both folds are seeded with the running totals so the accumulation
        // order is one continuous chain across blocks.
        vector_sum = lhs_vecs
            .iter()
            .zip(weight_vecs)
            .fold(vector_sum, |sum, (&x, &w)| simd.mul_add_f32s(x, w, sum));

        tail_sum = lhs_tail
            .iter()
            .zip(weight_tail)
            .fold(tail_sum, |sum, (&x, &w)| sum + x * w);
    }

    <SumOp as SimdReduceOp<f32>>::reduce_simd_vec(simd, vector_sum) + tail_sum
}