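//! Elementwise (unary) tensor operations: negation, absolute value, square root,
//! exponentials/logarithms, and trigonometric/hyperbolic functions (with inverses).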

use std::ops::Neg as StdNeg;

use pulp::Simd;

use crate::{ConcreteTensor, SimdElement, TensorBacking, materialize_expr};
use fusor_types::Layout;

/// Trait for unary operations that have SIMD support.
///
/// An implementor describes one elementwise operation on elements of type `E`
/// in two forms: a vector kernel and a scalar kernel. Both are associated
/// functions (no `self`), so the implementing type only selects the operation.
pub trait SimdUnaryOp<E: SimdElement>: Copy {
    /// Apply the operation to a SIMD vector.
    ///
    /// `simd` is the `pulp::Simd` instance the kernel may use to issue vector
    /// operations; `a` is the input vector of type `E::Simd<S>`. Returns a
    /// vector of the same type.
    fn apply_simd_vec<S: Simd>(simd: S, a: E::Simd<S>) -> E::Simd<S>;

    /// Apply the operation to a single scalar value and return the result.
    fn apply_scalar(val: E) -> E;
}

// Unary operation markers.
//
// Each marker is a field-less unit struct. It carries no data; the operation it
// names is supplied by the `SimdUnaryOp` implementations written for it.
macro_rules! define_op_marker {
    ($($name:ident),* $(,)?) => {
        $(
            /// Unit marker type naming a unary elementwise operation.
            // Public types should be printable and comparable; all of these
            // derives are free for a unit struct.
            #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
            pub struct $name;
        )*
    };
}
define_op_marker!(
    NegOp, AbsOp, SqrtOp, ExpOp, Exp2Op, LogOp, Log2Op, SinOp, CosOp, TanOp, TanhOp, AsinOp,
    AcosOp, AtanOp, SinhOp, CoshOp, AsinhOp, AcoshOp, AtanhOp
);

// Implements `SimdUnaryOp<$elem>` for `$op` when the `Simd` trait offers a
// vector method for the operation.
//
// Arguments, in order: the marker type, a scalar function (any expression that
// coerces to `fn($elem) -> $elem`), the name of the `Simd` method to call on
// the vector, and the element type.
macro_rules! impl_unary_op {
    ($op:ty, $scalar_fn:expr, $simd_method:ident, $elem:ty) => {
        impl SimdUnaryOp<$elem> for $op {
            #[inline(always)]
            fn apply_simd_vec<S: Simd>(
                simd: S,
                a: <$elem as SimdElement>::Simd<S>,
            ) -> <$elem as SimdElement>::Simd<S> {
                // Hand the whole vector to the named `Simd` method.
                <S as Simd>::$simd_method(simd, a)
            }

            #[inline(always)]
            fn apply_scalar(val: $elem) -> $elem {
                // Binding with an explicit fn-pointer type checks the signature
                // of whatever expression the caller supplied.
                let scalar_kernel: fn($elem) -> $elem = $scalar_fn;
                scalar_kernel(val)
            }
        }
    };
}

// NegOp for floating-point elements: scalar negation via `std::ops::Neg`,
// vector negation via the `neg_f32s` / `neg_f64s` methods.
impl_unary_op!(NegOp, <f32 as StdNeg>::neg, neg_f32s, f32);
impl_unary_op!(NegOp, <f64 as StdNeg>::neg, neg_f64s, f64);

// NegOp for integer element types, expressed as `0 - a`.
//
// Arguments: the element type, the `Simd` method that broadcasts a scalar to a
// vector, and the `Simd` method that subtracts two vectors.
macro_rules! impl_neg_int_op {
    ($elem:ty, $splat:ident, $sub:ident) => {
        impl SimdUnaryOp<$elem> for NegOp {
            #[inline(always)]
            fn apply_simd_vec<S: Simd>(
                simd: S,
                a: <$elem as SimdElement>::Simd<S>,
            ) -> <$elem as SimdElement>::Simd<S> {
                // Broadcast zero, then subtract the input from it.
                let zeros = simd.$splat(0);
                simd.$sub(zeros, a)
            }

            #[inline(always)]
            fn apply_scalar(val: $elem) -> $elem {
                // Wrapping negation never panics; the minimum value maps to itself.
                <$elem>::wrapping_neg(val)
            }
        }
    };
}

// Instantiate the integer NegOp for i8, i16, i32 and i64.
// Arguments are (element type, splat method, subtract method).
impl_neg_int_op!(i8, splat_i8s, sub_i8s);
impl_neg_int_op!(i16, splat_i16s, sub_i16s);
impl_neg_int_op!(i32, splat_i32s, sub_i32s);
impl_neg_int_op!(i64, splat_i64s, sub_i64s);

// AbsOp for floating-point elements: scalar `abs`, vector `abs_f32s` / `abs_f64s`.
impl_unary_op!(AbsOp, f32::abs, abs_f32s, f32);
impl_unary_op!(AbsOp, f64::abs, abs_f64s, f64);

// AbsOp for integer element types, expressed as `max(a, 0 - a)`.
//
// Arguments: the element type, then the `Simd` methods for broadcasting a
// scalar, subtracting two vectors, and taking the lane-wise maximum.
macro_rules! impl_abs_int_op {
    ($elem:ty, $splat:ident, $sub:ident, $max:ident) => {
        impl SimdUnaryOp<$elem> for AbsOp {
            #[inline(always)]
            fn apply_simd_vec<S: Simd>(
                simd: S,
                a: <$elem as SimdElement>::Simd<S>,
            ) -> <$elem as SimdElement>::Simd<S> {
                // Negate by subtracting from a zero vector, then keep the larger
                // of the input and its negation in every lane.
                let negated = simd.$sub(simd.$splat(0), a);
                simd.$max(a, negated)
            }

            #[inline(always)]
            fn apply_scalar(val: $elem) -> $elem {
                // Wrapping absolute value never panics; the minimum value maps to itself.
                <$elem>::wrapping_abs(val)
            }
        }
    };
}

// Instantiate the integer AbsOp for i8, i16, i32 and i64.
// Arguments are (element type, splat method, subtract method, max method).
impl_abs_int_op!(i8, splat_i8s, sub_i8s, max_i8s);
impl_abs_int_op!(i16, splat_i16s, sub_i16s, max_i16s);
impl_abs_int_op!(i32, splat_i32s, sub_i32s, max_i32s);
impl_abs_int_op!(i64, splat_i64s, sub_i64s, max_i64s);

// SqrtOp for floating-point elements: scalar `sqrt`, vector `sqrt_f32s` / `sqrt_f64s`.
impl_unary_op!(SqrtOp, f32::sqrt, sqrt_f32s, f32);
impl_unary_op!(SqrtOp, f64::sqrt, sqrt_f64s, f64);

// Implements `SimdUnaryOp<$elem>` for ops that have no vector method here.
// The vector path runs the scalar function once per lane, so these ops can
// still take part in fused expressions.
//
// Arguments: the marker type, a scalar function (any expression that coerces
// to `fn($elem) -> $elem`), and the element type.
macro_rules! impl_scalar_unary_op {
    ($op:ty, $scalar_fn:expr, $elem:ty) => {
        impl SimdUnaryOp<$elem> for $op {
            #[inline(always)]
            fn apply_simd_vec<S: Simd>(
                _simd: S,
                a: <$elem as SimdElement>::Simd<S>,
            ) -> <$elem as SimdElement>::Simd<S> {
                let scalar_kernel: fn($elem) -> $elem = $scalar_fn;

                // Lane count is the vector size divided by the element size.
                let vector_bytes = std::mem::size_of::<<$elem as SimdElement>::Simd<S>>();
                let lanes = vector_bytes / std::mem::size_of::<$elem>();

                // Stack scratch space; the slicing below panics if `lanes`
                // exceeds `crate::MAX_SIMD_LANES`.
                let mut scratch = [<$elem>::default(); crate::MAX_SIMD_LANES];

                // Reinterpret the input vector as its scalar lanes (checked cast
                // through bytemuck) and copy them into the scratch buffer.
                let source: &[$elem] = pulp::bytemuck::cast_slice(std::slice::from_ref(&a));
                let active = &mut scratch[..lanes];
                active.copy_from_slice(source);

                // Transform every lane in place.
                for lane in active.iter_mut() {
                    *lane = scalar_kernel(*lane);
                }

                // Re-view the processed lanes as vectors; exactly one vector's
                // worth of data was written, so the first entry is the result.
                let (vectors, _rest) = <$elem as SimdElement>::as_simd::<S>(&scratch[..lanes]);
                vectors[0]
            }

            #[inline(always)]
            fn apply_scalar(val: $elem) -> $elem {
                let scalar_kernel: fn($elem) -> $elem = $scalar_fn;
                scalar_kernel(val)
            }
        }
    };
}

// Exponential, logarithmic, trigonometric and tanh ops for f32
impl_scalar_unary_op!(ExpOp, f32::exp, f32);
impl_scalar_unary_op!(Exp2Op, f32::exp2, f32);
impl_scalar_unary_op!(LogOp, f32::ln, f32);
impl_scalar_unary_op!(Log2Op, f32::log2, f32);
impl_scalar_unary_op!(SinOp, f32::sin, f32);
impl_scalar_unary_op!(CosOp, f32::cos, f32);
impl_scalar_unary_op!(TanOp, f32::tan, f32);
impl_scalar_unary_op!(TanhOp, f32::tanh, f32);

// The same set for f64
impl_scalar_unary_op!(ExpOp, f64::exp, f64);
impl_scalar_unary_op!(Exp2Op, f64::exp2, f64);
impl_scalar_unary_op!(LogOp, f64::ln, f64);
impl_scalar_unary_op!(Log2Op, f64::log2, f64);
impl_scalar_unary_op!(SinOp, f64::sin, f64);
impl_scalar_unary_op!(CosOp, f64::cos, f64);
impl_scalar_unary_op!(TanOp, f64::tan, f64);
impl_scalar_unary_op!(TanhOp, f64::tanh, f64);

// Inverse trigonometric, hyperbolic and inverse hyperbolic ops for f32
impl_scalar_unary_op!(AsinOp, f32::asin, f32);
impl_scalar_unary_op!(AcosOp, f32::acos, f32);
impl_scalar_unary_op!(AtanOp, f32::atan, f32);
impl_scalar_unary_op!(SinhOp, f32::sinh, f32);
impl_scalar_unary_op!(CoshOp, f32::cosh, f32);
impl_scalar_unary_op!(AsinhOp, f32::asinh, f32);
impl_scalar_unary_op!(AcoshOp, f32::acosh, f32);
impl_scalar_unary_op!(AtanhOp, f32::atanh, f32);

// The same set for f64
impl_scalar_unary_op!(AsinOp, f64::asin, f64);
impl_scalar_unary_op!(AcosOp, f64::acos, f64);
impl_scalar_unary_op!(AtanOp, f64::atan, f64);
impl_scalar_unary_op!(SinhOp, f64::sinh, f64);
impl_scalar_unary_op!(CoshOp, f64::cosh, f64);
impl_scalar_unary_op!(AsinhOp, f64::asinh, f64);
impl_scalar_unary_op!(AcoshOp, f64::acosh, f64);
impl_scalar_unary_op!(AtanhOp, f64::atanh, f64);

// f16 unary ops: pulp has no native f16 SIMD, so `f16::Simd<S> = F16Scalar`
// (a single-lane wrapper; see cpu/src/lib.rs:F16Scalar). The "vector" path
// therefore just applies the scalar function to the one wrapped value.
//
// Arguments: the marker type and a scalar function (any expression that
// coerces to `fn(half::f16) -> half::f16`). Most invocations in this file pass
// a closure that computes in f32 and rounds the result back to f16.
macro_rules! impl_f16_unary_op {
    ($op:ty, $f:expr) => {
        impl SimdUnaryOp<half::f16> for $op {
            #[inline(always)]
            fn apply_simd_vec<S: Simd>(_simd: S, a: crate::F16Scalar) -> crate::F16Scalar {
                // Unwrap the single lane, run the scalar kernel, wrap it again.
                crate::F16Scalar(<Self as SimdUnaryOp<half::f16>>::apply_scalar(a.0))
            }

            #[inline(always)]
            fn apply_scalar(val: half::f16) -> half::f16 {
                let scalar_kernel: fn(half::f16) -> half::f16 = $f;
                scalar_kernel(val)
            }
        }
    };
}

impl_f16_unary_op!(NegOp, |x: half::f16| -x);
impl_f16_unary_op!(AbsOp, |x: half::f16| half::f16::from_f32(x.to_f32().abs()));
impl_f16_unary_op!(SqrtOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().sqrt()
));
impl_f16_unary_op!(ExpOp, |x: half::f16| half::f16::from_f32(x.to_f32().exp()));
impl_f16_unary_op!(Exp2Op, |x: half::f16| half::f16::from_f32(
    x.to_f32().exp2()
));
impl_f16_unary_op!(LogOp, |x: half::f16| half::f16::from_f32(x.to_f32().ln()));
impl_f16_unary_op!(Log2Op, |x: half::f16| half::f16::from_f32(
    x.to_f32().log2()
));
impl_f16_unary_op!(SinOp, |x: half::f16| half::f16::from_f32(x.to_f32().sin()));
impl_f16_unary_op!(CosOp, |x: half::f16| half::f16::from_f32(x.to_f32().cos()));
impl_f16_unary_op!(TanOp, |x: half::f16| half::f16::from_f32(x.to_f32().tan()));
impl_f16_unary_op!(TanhOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().tanh()
));
impl_f16_unary_op!(AsinOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().asin()
));
impl_f16_unary_op!(AcosOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().acos()
));
impl_f16_unary_op!(AtanOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().atan()
));
impl_f16_unary_op!(SinhOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().sinh()
));
impl_f16_unary_op!(CoshOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().cosh()
));
impl_f16_unary_op!(AsinhOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().asinh()
));
impl_f16_unary_op!(AcoshOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().acosh()
));
impl_f16_unary_op!(AtanhOp, |x: half::f16| half::f16::from_f32(
    x.to_f32().atanh()
));

/// Defines a lazy unary tensor expression type `$name` that applies `$simd_op`
/// elementwise to an input backing.
///
/// Two invocation forms are accepted:
/// - `define_unary_tensor_op!(Name, Op)`
/// - `define_unary_tensor_op!(Name, Op, Trait)`, which additionally requires
///   `E: Trait<Output = E>` on the generated `LazyBacking` and `TensorBacking`
///   implementations.
///
/// The generated type stores its input and implements `crate::LazyBacking`
/// (per-element and per-vector evaluation) and `TensorBacking<R>` (layout and
/// materialization).
macro_rules! define_unary_tensor_op {
    ($name:ident, $simd_op:ty $(, $std_trait:ident)?) => {
        pub struct $name<E: SimdElement, const R: usize, T: TensorBacking<R, Elem = E>> {
            input: T,
            _marker: std::marker::PhantomData<E>,
        }

        impl<E, const R: usize, T> $name<E, R, T>
        where
            E: SimdElement,
            T: TensorBacking<R, Elem = E>,
        {
            /// Wraps `input`; nothing is evaluated at construction time.
            pub fn new(input: T) -> Self {
                Self { input, _marker: std::marker::PhantomData }
            }
        }

        impl<E, const R: usize, T> crate::LazyBacking for $name<E, R, T>
        where
            E: SimdElement $(+ $std_trait<Output = E>)? + Default,
            $simd_op: SimdUnaryOp<E>,
            T: TensorBacking<R, Elem = E>,
        {
            type Elem = E;

            #[inline(always)]
            fn eval_scalar(&self, idx: usize) -> E {
                // Evaluate the input element first, then apply the scalar kernel.
                let element = self.input.eval_scalar(idx);
                <$simd_op as SimdUnaryOp<E>>::apply_scalar(element)
            }

            #[inline(always)]
            fn eval_simd<S: Simd>(&self, simd: S, base_idx: usize) -> E::Simd<S> {
                // Evaluate the input vector first, then apply the vector kernel.
                let vector = self.input.eval_simd(simd, base_idx);
                <$simd_op as SimdUnaryOp<E>>::apply_simd_vec(simd, vector)
            }
        }

        impl<E, const R: usize, T> TensorBacking<R> for $name<E, R, T>
        where
            E: SimdElement $(+ $std_trait<Output = E>)? + Default,
            $simd_op: SimdUnaryOp<E>,
            T: TensorBacking<R, Elem = E>,
        {
            fn layout(&self) -> Layout {
                // A contiguous layout built from the input's shape.
                let input_layout = self.input.layout();
                Layout::contiguous(input_layout.shape())
            }

            fn to_concrete(&self) -> ConcreteTensor<E, R> {
                // Panics if the input's shape does not have exactly `R` dimensions.
                let input_layout = self.input.layout();
                let shape: [usize; R] = input_layout
                    .shape()
                    .try_into()
                    .expect("Shape length mismatch");
                materialize_expr(self, shape)
            }
        }
    };
}

// Unary tensor operations.
// `Neg` passes `StdNeg` (this file's alias for `std::ops::Neg`) as the extra
// trait argument, so its impls also require `E: StdNeg<Output = E>`.
define_unary_tensor_op!(Neg, NegOp, StdNeg);
define_unary_tensor_op!(Abs, AbsOp);
define_unary_tensor_op!(Sqrt, SqrtOp);

// Transcendental tensor operations (exponentials, logarithms, sin/cos/tan, tanh)
define_unary_tensor_op!(Exp, ExpOp);
define_unary_tensor_op!(Exp2, Exp2Op);
define_unary_tensor_op!(Log, LogOp);
define_unary_tensor_op!(Log2, Log2Op);
define_unary_tensor_op!(Sin, SinOp);
define_unary_tensor_op!(Cos, CosOp);
define_unary_tensor_op!(Tan, TanOp);
define_unary_tensor_op!(Tanh, TanhOp);

// Additional inverse trig and hyperbolic tensor operations
define_unary_tensor_op!(Asin, AsinOp);
define_unary_tensor_op!(Acos, AcosOp);
define_unary_tensor_op!(Atan, AtanOp);
define_unary_tensor_op!(Sinh, SinhOp);
define_unary_tensor_op!(Cosh, CoshOp);
define_unary_tensor_op!(Asinh, AsinhOp);
define_unary_tensor_op!(Acosh, AcoshOp);
define_unary_tensor_op!(Atanh, AtanhOp);