package scientific.stats.continuous

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.linear.*
import scientific.numbers.*
import scientific.stats.random.*

/*
 * Lanczos approximation parameters shared by gamma() and gammaLog().
 *   _lanczos_g - the Lanczos shift g (used as t = x - 1 + g + 0.5)
 *   _lanczos_n - number of coefficients in _lanczos_p (9).
 *                NOTE(review): not referenced by gamma()/gammaLog(), which
 *                iterate over _lanczos_p.size instead; kept for reference.
 *   _lanczos_p - series coefficients p_0 .. p_8 for this (g = 7, n = 9) set;
 *                the values must stay exactly as listed.
 */
let _lanczos_g = 7
let _lanczos_n = 9
let _lanczos_p = [
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7
]

/*
 * Gamma function Γ(x) using the Lanczos approximation with shift
 * _lanczos_g and coefficients _lanczos_p.
 *
 * For x < 0.5 the reflection formula Γ(x) = π / (sin(πx) Γ(1 - x)) is used.
 * Poles and overflow:
 *   - x = ±0 gives ±Inf (division by sin(±0) = ±0).
 *   - Negative integers (and -Inf) give NaN. In floating point sin(πx) is only
 *     approximately zero there, which would otherwise produce a huge finite
 *     value of arbitrary sign.
 *   - Γ(x) exceeds the largest Float64 for x > ~171.62, so +Inf is returned
 *     for x > 171.7. This also covers x = +Inf, which would otherwise
 *     evaluate Inf * 0 = NaN.
 */
public func gamma(x: Float64): Float64 {
    if (x < 0.5) {
        if (x < 0.0 && x == floor(x)) {
            // Negative integer: pole of Γ with no well-defined sign
            return Float64.NaN
        }
        // Reflection formula
        return Float64.getPI() / (sin(Float64.getPI() * x) * gamma(1.0 - x))
    }
    if (x > 171.7) {
        return Float64.Inf
    }
    let z = x - 1.0
    var res = _lanczos_p[0]
    for (i in 1.._lanczos_p.size) {
        res += _lanczos_p[i] / (z + Float64(i))
    }
    let t = z + Float64(_lanczos_g) + 0.5
    // t^(x - 0.5) is split into two half powers with exp(-t) applied in between,
    // so no intermediate overflows while Γ(x) itself is still representable.
    let halfPow = t ** ((x - 0.5) / 2.0)
    return sqrt(2.0 * Float64.getPI()) * halfPow * (halfPow * exp(-t)) * res
}

/*
 * Natural logarithm of the absolute value of the Gamma function, log|Γ(x)|.
 * It uses the same Lanczos approximation as gamma() but evaluates it in log
 * space, so it stays finite for large arguments where Γ(x) itself overflows.
 *
 * For x < 0.5 the reflection formula
 *     log|Γ(x)| = log(π) - log|sin(πx)| - log|Γ(1 - x)|
 * is used. Non-positive integers are poles of Γ, so +Inf is returned there
 * explicitly. At negative integers sin(πx) is only approximately zero in
 * floating point, which would otherwise yield a large finite value.
 */
public func gammaLog(x: Float64): Float64 {
    if (x < 0.5) {
        if (x <= 0.0 && x == floor(x)) {
            // Pole of Γ: |Γ(x)| -> Inf
            return Float64.Inf
        }
        // Reflection formula
        return log(Float64.getPI()) - log(abs(sin(Float64.getPI() * x))) - gammaLog(1.0 - x)
    }
    let z = x - 1.0
    var res = _lanczos_p[0]
    for (i in 1.._lanczos_p.size) {
        res += _lanczos_p[i] / (z + Float64(i))
    }
    let t = x - 1.0 + Float64(_lanczos_g) + 0.5
    return log(sqrt(2.0 * Float64.getPI())) + log(t) * (x - 0.5) - t + log(res)
}

/*
 * Logarithm of the multivariate gamma function of dimension `dim`:
 *     log Γ_dim(rval) = dim (dim - 1) / 4 * log(π)
 *                       + Σ_{j = 0}^{dim - 1} log Γ(rval - j / 2)
 * Each log Γ term is evaluated with gammaLog().
 */
public func multiGammaLog(rval: Float64, dim: Int64): Float64 {
    let p = Float64(dim)
    let piTerm = p * (p - 1.0) / 4.0 * log(Float64.getPI())

    var logGammaSum = 0.0
    var j = 0
    while (j < dim) {
        logGammaSum += gammaLog(rval - Float64(j) / 2.0)
        j += 1
    }
    return piTerm + logGammaSum
}

/*
 * Probability density function of the standard Gamma distribution with
 * shape `alpha` (scale 1):
 *     f(x) = x^(alpha - 1) e^(-x) / Γ(alpha)   for x >= 0
 *     f(x) = 0                                  for x < 0
 *
 * For x > 0 the density is evaluated in log space (via gammaLog). This keeps
 * it finite when x^(alpha - 1) or Γ(alpha) would overflow on their own
 * (e.g. alpha = 200, x = 200). NaN is returned when alpha is not > 0.
 */
public func gammaPDF(x: Float64, alpha: Float64): Float64 {
    if (!(alpha > 0.0)) {
        return Float64.NaN
    }
    if (x < 0.0) {
        // Outside the support; pow() of a negative base would give NaN or a spurious value
        return 0.0
    }
    if (x == 0.0) {
        // log(0) = -Inf would give NaN for alpha == 1. The direct formula yields
        // 0 (alpha > 1), 1 (alpha == 1) or +Inf (alpha < 1).
        return pow(x, alpha - 1.0) / gamma(alpha)
    }
    return exp((alpha - 1.0) * log(x) - x - gammaLog(alpha))
}

/*
 * Regularized lower incomplete gamma function P(a, x) = γ(a, x) / Γ(a).
 * This is the CDF at x of the Gamma distribution with shape `a` and scale 1.
 *
 * Two evaluation strategies are used:
 *   - x < a + 1: the power series
 *         P(a, x) = x^a e^(-x) / Γ(a) * Σ_{n>=0} x^n / (a (a+1) ... (a+n))
 *     converges quickly.
 *   - x >= a + 1: that series would need far more than `max_iter` terms.
 *     Instead, the complement Q(a, x) = 1 - P(a, x) is evaluated with a
 *     continued fraction (modified Lentz algorithm).
 * The common prefactor x^a e^(-x) / Γ(a) is formed in log space, so it does
 * not overflow for large `a` or `x`.
 *
 * Parameters:
 *   a        - shape, must be > 0 (NaN is returned otherwise)
 *   x        - evaluation point; values <= 0 return 0.0
 *   max_iter - maximum number of series terms / continued-fraction steps.
 *              NOTE: when x is close to a very large a, both expansions need
 *              roughly O(sqrt(a)) steps; raise max_iter in that regime.
 *   tol      - relative convergence tolerance
 */
public func igam(a: Float64, x: Float64, max_iter!: Int64 = 100, tol!: Float64 = 1.0e-14): Float64 {
    if (a.isNaN() || x.isNaN() || a <= 0.0) {
        return Float64.NaN
    }
    if (x <= 0.0) {
        return 0.0
    }
    // log(x^a e^(-x) / Γ(a))
    let logPrefactor = a * log(x) - x - gammaLog(a)

    if (x < a + 1.0) {
        // Power series for P(a, x)
        var sum = 1.0 / a
        var term = sum
        for (n in 1..=max_iter) {
            term *= x / (a + Float64(n))
            sum += term
            if (abs(term) < abs(sum) * tol) {
                break
            }
        }
        return sum * exp(logPrefactor)
    }

    // Continued fraction for Q(a, x), modified Lentz algorithm
    let tiny = 1.0e-300 // guards against division by (near) zero
    var b = x + 1.0 - a // >= 2 here because x >= a + 1
    var c = 1.0 / tiny
    var d = 1.0 / b
    var h = d
    for (i in 1..=max_iter) {
        let fi = Float64(i)
        let an = -fi * (fi - a)
        b += 2.0
        d = an * d + b
        if (abs(d) < tiny) {
            d = tiny
        }
        c = b + an / c
        if (abs(c) < tiny) {
            c = tiny
        }
        d = 1.0 / d
        let step = d * c
        h *= step
        if (abs(step - 1.0) < tol) {
            break
        }
    }
    return 1.0 - exp(logPrefactor) * h
}

/*
 * Cumulative distribution function of the Gamma distribution with shape `k`,
 * evaluated at `x_scaled` via igam(k, x_scaled).
 * Negative arguments are below the support and map to 0.0. All other inputs,
 * including NaN, are passed straight to igam.
 * NOTE(review): the `_scaled` suffix suggests callers divide x by the scale
 * parameter beforehand; confirm against callers.
 */
public func gammaCDF(x_scaled: Float64, k: Float64): Float64 {
    return if (x_scaled < 0.0) { 0.0 } else { igam(k, x_scaled) }
}

/*
 * Percent-point function (inverse CDF) of the Gamma distribution with shape
 * `alpha` (scale 1): returns x such that gammaCDF(x, alpha) == pp.
 *
 * Parameters:
 *   pp       - target probability; pp <= 0 returns 0.0, pp >= 1 returns +Inf
 *   alpha    - shape, must be > 0 (NaN is returned otherwise, or if pp is NaN)
 *   max_iter - maximum number of iterations; the last iterate is returned
 *              if the tolerance is not reached
 *   tol      - stop when the step is below tol * (1 + |x|)
 *
 * Algorithm: Newton-Raphson safeguarded by a bracket [lo, hi] known to contain
 * the root. Because the CDF is increasing, every evaluation tightens the
 * bracket. When the Newton step is not finite (zero, infinite or NaN density)
 * or leaves the bracket, the iteration falls back to one of two moves:
 *   - double x while there is no upper bound yet, or
 *   - bisect the bracket once there is.
 * Iterates therefore stay strictly positive and cannot stall at or below zero.
 */
public func gammaPPF(pp: Float64, alpha: Float64, max_iter!: Int64 = 100, tol!: Float64 = 1e-14): Float64 {
    if (pp.isNaN() || alpha.isNaN() || alpha <= 0.0) {
        return Float64.NaN
    }
    if (pp <= 0.0) {
        return 0.0
    }
    if (pp >= 1.0) {
        return Float64.Inf
    }

    // Initial guess. For alpha >= 1, alpha - 1/3 approximates the median.
    // For alpha < 1 that value can be <= 0 (alpha <= 1/3). Near zero,
    // P(alpha, x) ≈ x^alpha / Γ(alpha + 1), and inverting that gives a
    // positive start instead.
    var x = if (alpha >= 1.0) {
        alpha * (1.0 - 1.0 / (3.0 * alpha))
    } else {
        exp((log(pp) + gammaLog(alpha + 1.0)) / alpha)
    }

    var lo = 0.0
    var hi = Float64.Inf
    for (_ in 0..max_iter) {
        let err = gammaCDF(x, alpha) - pp
        if (err.isNaN()) {
            return Float64.NaN
        }
        if (err == 0.0) {
            return x
        }
        // The CDF is increasing, so the sign of err says which side of the root x is on
        if (err < 0.0) {
            lo = x
        } else {
            hi = x
        }

        // Newton step. A non-finite step fails the bracket test below
        // (NaN and +/-Inf compare false).
        var newX = x - err / gammaPDF(x, alpha)
        if (!(newX > lo && newX < hi)) {
            newX = if (hi == Float64.Inf) { 2.0 * x } else { 0.5 * (lo + hi) }
        }

        // Check convergence
        if (abs(newX - x) < tol * (1.0 + abs(x))) {
            return newX
        }
        x = newX
    }

    return x
}

/*
 * Inverse of the complemented (upper) regularized incomplete gamma function
 * Q(a, x) = 1 - P(a, x): returns x such that Q(a, x) = p.
 * Computed as the gamma quantile of the lower-tail probability 1 - p.
 */
public func igami(a: Float64, p: Float64): Float64 {
    let lowerTail = 1.0 - p
    return gammaPPF(lowerTail, a)
}

/*
 * Draws one sample from the Gamma distribution with shape `alpha` (scale 1)
 * by inverse-transform sampling. A variate from `r.nextFloat64()` is mapped
 * through gammaPPF.
 * NOTE(review): assumes nextFloat64() is uniform on [0, 1); confirm against Random.
 */
public func gamSample(r: Random, alpha: Float64): Float64 {
    return gammaPPF(r.nextFloat64(), alpha)
}

/*
 * Unit tests for the gamma-function utilities in this file.
 * Each case compares a computed value against a hard-coded reference with
 * approxEqual and an absolute tolerance (`atol`).
 * NOTE(review): approxEqual and linspace come from the imported scientific.*
 * packages; their exact semantics are not visible here.
 */
@Test
public class TestGamma {
    // Γ(n) = (n - 1)! for positive integers n = 1..5
    @TestCase
    func testGammaInteger(): Unit {
        @Assert(approxEqual(gamma(1.0), 1.0,  atol:1e-13))
        @Assert(approxEqual(gamma(2.0), 1.0,  atol:1e-13))
        @Assert(approxEqual(gamma(3.0), 2.0,  atol:1e-13))
        @Assert(approxEqual(gamma(4.0), 6.0,  atol:1e-13))
        @Assert(approxEqual(gamma(5.0), 24.0, atol:1e-13))
    }

    // Γ at positive half-integers; Γ(0.5) = sqrt(π)
    @TestCase
    func testGammaPosHalfInteger(): Unit {
        @Assert(approxEqual(gamma(0.5), 1.7724538509055159, atol:1e-13))
        @Assert(approxEqual(gamma(1.5), 0.886226925452758,  atol:1e-13))
        @Assert(approxEqual(gamma(2.5), 1.3293403881791372, atol:1e-13))
        @Assert(approxEqual(gamma(3.5), 3.323350970447842,  atol:1e-13))
        @Assert(approxEqual(gamma(4.5), 11.631728396567446, atol:1e-13))
    }

    // Γ at negative half-integers; exercises the reflection branch (x < 0.5)
    @TestCase
    func testGammaNegHalfInteger(): Unit {
        @Assert(approxEqual(gamma(-0.5), -3.544907701811032, atol:1e-13))
        @Assert(approxEqual(gamma(-1.5), 2.3632718012073544, atol:1e-13))
        @Assert(approxEqual(gamma(-2.5), -0.9453087204829417, atol:1e-13))
        @Assert(approxEqual(gamma(-3.5), 0.27008820585226917, atol:1e-13))
        @Assert(approxEqual(gamma(-4.5), -0.06001960130050425, atol:1e-13))
    }

    // Checks the identity Γ(x) Γ(1 - x) = π / sin(πx) on 10 points of
    // linspace(0.5, 9.5) (presumably 0.5, 1.5, ..., 9.5).
    @TestCase
    func testGammaReflection(): Unit {
        let values: Vector<Float64> = linspace(0.5, 9.5, num: 10)
        for (i in 0..values.size()) {
            @Assert(approxEqual(
                gamma(values[i]) * gamma(1.0 - values[i]),
                Float64.getPI() / sin(Float64.getPI() * values[i]),
                atol: 1e-12
            ))
        }
    }

    // log Γ at positive half-integers; gammaLog(0.5) = log(sqrt(π))
    @TestCase
    func testGammaLog(): Unit {
        @Assert(approxEqual(gammaLog(0.5), 0.5723649429247, atol:1e-13))
        @Assert(approxEqual(gammaLog(1.5), -0.12078223763524526, atol:1e-13))
        @Assert(approxEqual(gammaLog(2.5), 0.2846828704729192, atol:1e-13))
        @Assert(approxEqual(gammaLog(3.5), 1.2009736023470743, atol:1e-13))
        @Assert(approxEqual(gammaLog(4.5), 2.4537365708424423, atol:1e-13))
    }

    // Density of the shape-4 gamma distribution at x = 1, 3, 5, 7, 9
    @TestCase
    func testGammaPDF(): Unit {
        @Assert(approxEqual(gammaPDF(1.0, 4.0), 0.06131324019524039,  atol:1e-13))
        @Assert(approxEqual(gammaPDF(3.0, 4.0), 0.22404180765538775,  atol:1e-13))
        @Assert(approxEqual(gammaPDF(5.0, 4.0), 0.1403738958142805,   atol:1e-13))
        @Assert(approxEqual(gammaPDF(7.0, 4.0), 0.052129252364199824, atol:1e-13))
        @Assert(approxEqual(gammaPDF(9.0, 4.0), 0.014994291196531574, atol:1e-13))
    }

    // CDF of the shape-4 gamma distribution at x = 1, 3, 5, 7, 9
    @TestCase
    func testGammaCDF(): Unit {
        @Assert(approxEqual(gammaCDF(1.0, 4.0), 0.01898815687615381,  atol:1e-13))
        @Assert(approxEqual(gammaCDF(3.0, 4.0), 0.35276811121776874,  atol:1e-13))
        @Assert(approxEqual(gammaCDF(5.0, 4.0), 0.7349740847026385,   atol:1e-13))
        @Assert(approxEqual(gammaCDF(7.0, 4.0), 0.9182345837552784,   atol:1e-13))
        @Assert(approxEqual(gammaCDF(9.0, 4.0), 0.9787735136970911,   atol:1e-13))
    }

    // Quantiles of the shape-4 gamma distribution at p = 0.1, 0.3, 0.5, 0.7, 0.9
    @TestCase
    func testGammaPPF(): Unit {
        @Assert(approxEqual(gammaPPF(0.1, 4.0), 1.7447695628249114,   atol:1e-13))
        @Assert(approxEqual(gammaPPF(0.3, 4.0), 2.7637110426126483,   atol:1e-13))
        @Assert(approxEqual(gammaPPF(0.5, 4.0), 3.672060748850897,    atol:1e-13))
        @Assert(approxEqual(gammaPPF(0.7, 4.0), 4.762229096535917,    atol:1e-13))
        @Assert(approxEqual(gammaPPF(0.9, 4.0), 6.680783068255865,    atol:1e-13))
    }

    // Log multivariate gamma in dimension 10; note the looser tolerance (1e-6)
    @TestCase
    func testMultiGammaLog(): Unit {
        @Assert(approxEqual(multiGammaLog(23.5, 10), 457.1766747829046,    atol:1e-6))
        @Assert(approxEqual(multiGammaLog(10.0, 10), 107.44809807262865,   atol:1e-6))
        @Assert(approxEqual(multiGammaLog(5.0,  10), 35.81035866069934,    atol:1e-6))
    }
}