package scientific.stats.normal

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.stats.random.Random
import scientific.linear.empty
import scientific.linear.Vector
import scientific.linear.Matrix
import scientific.numbers.approxEqual


/*
 * Reference [1]: The Truncated Normal Distribution
 * https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
 */

/*
 * Probability density function of the standard normal distribution N(0, 1):
 *     phi(x) = exp(-x^2 / 2) / sqrt(2 * pi)
 */
public func normalPDF(x: Float64): Float64 {
    let invNormalizer = 1.0 / sqrt(2.0 * Float64.getPI())
    let exponent = -x * x / 2.0
    return invNormalizer * exp(exponent)
}

/*
 * Cumulative distribution function of the standard normal distribution N(0, 1),
 * expressed through the error function:
 *     Phi(x) = (1 + erf(x / sqrt(2))) / 2
 */
public func normalCDF(x: Float64): Float64 {
    let scaled = x / sqrt(2.0)
    return (erf(scaled) + 1.0) / 2.0
}

// Binding to the C runtime allocator `malloc`; returns an untyped pointer to `size` bytes.
foreign func malloc(size: UIntNative) : CPointer<Unit>
// Binding to the C routine `ppnd`, which evaluates the standard normal percent point
// function at `p` and reports failure through the `ifault` out-parameter.
// NOTE(review): presumably the AS 111 / AS 241 `ppnd` routine, where ifault = 1 means
// p is outside (0, 1) — confirm against the linked C source.
foreign func ppnd(p: Float64, ifault: CPointer<IntNative>): Float64

/*
 * Percent point function (inverse of the CDF) of the standard normal distribution.
 *
 * Delegates to the foreign routine `ppnd`, which reports failure through its
 * `ifault` out-parameter. A fault code of 1 is reported to the caller as an
 * out-of-range probability.
 *
 * Throws IllegalArgumentException when `ppnd` sets ifault to 1.
 */
public func normalPPF(x: Float64): Float64 {
    // Storage for the out-parameter comes from a local variable passed with `inout`,
    // so nothing is allocated or leaked. The previous malloc(4) was smaller than an
    // IntNative on 64-bit targets (ppnd wrote past the buffer), was left
    // uninitialized, and was never freed.
    var ifault: IntNative = 0
    let v = unsafe { ppnd(x, inout ifault) }
    if (ifault == 1) {
        throw IllegalArgumentException("normalPPF: input not between 0 and 1")
    }
    return v
}

/*
 * Draws one sample from the standard normal distribution by inverse-transform
 * sampling: a uniform draw p is mapped through normalPPF.
 *
 * The percent point function is infinite at p = 0 and p = 1, and normalPPF throws
 * there, so endpoint draws are discarded and redrawn. Excluding a single point does
 * not change the sampled distribution.
 * NOTE(review): assumes Random.nextFloat64() is uniform on [0, 1) — confirm.
 */
public func normalSampleFloat64(m: Random): Float64 {
    var p = m.nextFloat64()
    while (p <= 0.0 || p >= 1.0) {
        p = m.nextFloat64()
    }
    return normalPPF(p)
}

/*
 * Probability density function of the general normal distribution N(mean, std^2):
 *     f(x) = exp(-(x - mean)^2 / (2 * std^2)) / (sqrt(2 * pi) * std)
 * No validation is done on `std`; callers are expected to pass std > 0.
 */
public func normalPDF(x: Float64, mean: Float64, std: Float64): Float64 {
    let variance = pow(std, 2.0)
    let squaredDeviation = pow(x - mean, 2.0)
    let coefficient = 1.0 / sqrt(2.0 * Float64.getPI()) / std
    return coefficient * exp(-squaredDeviation / (2.0 * variance))
}

/*
 * Normal PDF together with a pullback for its gradient with respect to the
 * distribution parameters (mean, std). This is not the derivative with respect to x.
 *
 * Returns (pdf, pullback), where pdf equals normalPDF(x, mean, std) and
 *     pullback(dy) = (dy * dpdf/dmean, dy * dpdf/dstd)
 * with
 *     dpdf/dmean = pdf * (x - mean) / std^2
 *     dpdf/dstd  = pdf * (-1 / std + (x - mean)^2 / std^3)
 *
 * The upstream gradient `dy` is now applied. Previously it was ignored, which is
 * only correct when dy == 1.0.
 */
public func normalPDFAcc(x: Float64, mean: Float64, std: Float64): (Float64, (Float64) -> (Float64, Float64)) {
    let temp = 1.0 / sqrt(2.0 * Float64.getPI()) / std * exp(-pow(x - mean, 2.0) / (2.0 * pow(std, 2.0)))
    // The partials depend only on the inputs, so compute them once rather than on every pullback call.
    let dMean = temp * (-(mean - x) / pow(std, 2.0))
    let dStd = temp * (-1.0 / std + pow(x - mean, 2.0) / pow(std, 3.0))
    return (temp, { dy: Float64 => (dy * dMean, dy * dStd) })
}


/*
 * Cumulative distribution function of the general normal distribution N(mean, std^2).
 * The input is standardized to z = (x - mean) / std, and the standard normal CDF
 * is evaluated at z.
 */
public func normalCDF(x: Float64, mean: Float64, std: Float64): Float64 {
    return normalCDF((x - mean) / std)
}

/*
 * Percent point function (inverse of the CDF) of the general normal distribution.
 * Computes the standard normal quantile, then maps it back with mean + std * z.
 * Throws whatever the standard-normal normalPPF throws for an out-of-range x.
 */
public func normalPPF(x: Float64, mean: Float64, std: Float64): Float64 {
    let z = normalPPF(x)
    return z * std + mean
}

/*
 * Draws one sample from N(mean, std^2) by inverse-transform sampling: a uniform
 * draw p is mapped through normalPPF(p, mean, std).
 *
 * The percent point function is infinite at p = 0 and p = 1, and normalPPF throws
 * there, so endpoint draws are discarded and redrawn. Excluding a single point does
 * not change the sampled distribution.
 * NOTE(review): assumes Random.nextFloat64() is uniform on [0, 1) — confirm.
 */
public func normalSampleFloat64(m: Random, mean: Float64, std: Float64): Float64 {
    var p = m.nextFloat64()
    while (p <= 0.0 || p >= 1.0) {
        p = m.nextFloat64()
    }
    return normalPPF(p, mean, std)
}

/*
 * Probability density function of N(mean, std^2) truncated to the interval (start, end).
 * Inside the interval, the normal density is renormalized by the probability mass
 * CDF(end) - CDF(start). At or beyond either bound, the density is 0.
 */
public func truncNormalPDF(x: Float64, mean: Float64, std: Float64, start: Float64, end: Float64): Float64 {
    if (x <= start || x >= end) {
        return 0.0
    }
    let mass = normalCDF(end, mean, std) - normalCDF(start, mean, std)
    return normalPDF(x, mean, std) / mass
}

/*
 * Cumulative distribution function of N(mean, std^2) truncated to (start, end).
 * Returns 0 at or below `start`, 1 at or above `end`, and otherwise
 *     (CDF(x) - CDF(start)) / (CDF(end) - CDF(start)).
 */
public func truncNormalCDF(x: Float64, mean: Float64, std: Float64, start: Float64, end: Float64): Float64 {
    if (x <= start) {
        return 0.0
    } else if (x >= end) {
        return 1.0
    }

    let lower = normalCDF(start, mean, std)
    let upper = normalCDF(end, mean, std)
    return (normalCDF(x, mean, std) - lower) / (upper - lower)
}

/*
 * Percent point function (inverse of the CDF) of N(mean, std^2) truncated to (start, end).
 * The probability x is mapped linearly onto [CDF(start), CDF(end)], and that value
 * is inverted with the untruncated normal PPF.
 *
 * Throws IllegalArgumentException when x < 0 or x > 1.
 */
public func truncNormalPPF(x: Float64, mean: Float64, std: Float64, start: Float64, end: Float64): Float64 {
    if (x > 1.0 || x < 0.0) {
        throw IllegalArgumentException("truncNormalPPF: input not between 0 and 1")
    }
    let lower = normalCDF(start, mean, std)
    let upper = normalCDF(end, mean, std)
    return normalPPF(lower + x * (upper - lower), mean, std)
}

/*
 * Draws one sample from N(mean, std^2) truncated to (start, end) by feeding a
 * uniform draw from `m` through truncNormalPPF.
 */
public func truncNormalSampleFloat64(m: Random, mean: Float64, std: Float64, start: Float64, end: Float64): Float64 {
    return truncNormalPPF(m.nextFloat64(), mean, std, start, end)
}

/*
 * Returns a vector of `size` values, each taken from r.nextFloat64().
 * NOTE(review): presumably uniform on [0, 1) — depends on Random; confirm.
 */
public func rand(r: Random, size: Int64): Vector<Float64> {
    let samples = empty<Float64>(size)
    var k = 0
    while (k < size) {
        samples[k] = r.nextFloat64()
        k += 1
    }
    return samples
}

/*
 * Returns a vector of `size` values, each r.nextFloat64() rescaled linearly from
 * the unit interval onto [lower, upper).
 */
public func rand(r: Random, size: Int64, lower: Float64, upper: Float64): Vector<Float64> {
    let samples = empty<Float64>(size)
    let width = upper - lower
    var k = 0
    while (k < size) {
        samples[k] = r.nextFloat64() * width + lower
        k += 1
    }
    return samples
}

/*
 * Returns a rows x cols matrix filled in row-major order with values from
 * r.nextFloat64().
 */
public func rand(r: Random, rows: Int64, cols: Int64): Matrix<Float64> {
    let grid = empty<Float64>(rows, cols)
    for (row in 0..rows) {
        for (col in 0..cols) {
            grid[row, col] = r.nextFloat64()
        }
    }
    return grid
}

/*
 * Returns a rows x cols matrix filled in row-major order with values from
 * r.nextFloat64(), rescaled linearly onto [lower, upper).
 */
public func rand(r: Random, rows: Int64, cols: Int64, lower: Float64, upper: Float64): Matrix<Float64> {
    let grid = empty<Float64>(rows, cols)
    let width = upper - lower
    for (row in 0..rows) {
        for (col in 0..cols) {
            grid[row, col] = r.nextFloat64() * width + lower
        }
    }
    return grid
}

/*
 * Returns a vector of `size` independent samples from N(mean, std^2), each drawn
 * with normalSampleFloat64.
 */
public func randn(r: Random, size: Int64, mean: Float64, std: Float64): Vector<Float64> {
    let samples = empty<Float64>(size)
    var k = 0
    while (k < size) {
        samples[k] = normalSampleFloat64(r, mean, std)
        k += 1
    }
    return samples
}

/*
 * Returns a rows x cols matrix of samples from N(mean, std^2), filled in
 * row-major order with normalSampleFloat64.
 */
public func randn(r: Random, rows: Int64, cols: Int64, mean: Float64, std: Float64): Matrix<Float64> {
    let grid = empty<Float64>(rows, cols)
    for (row in 0..rows) {
        for (col in 0..cols) {
            grid[row, col] = normalSampleFloat64(r, mean, std)
        }
    }
    return grid
}

/*
 * Returns a vector of `size` samples from N(mean, std^2) truncated to (start, end),
 * each drawn with truncNormalSampleFloat64.
 */
public func randn_trunc(r: Random, size: Int64, mean: Float64, std: Float64,
                        start: Float64, end: Float64): Vector<Float64> {
    let samples = empty<Float64>(size)
    var k = 0
    while (k < size) {
        samples[k] = truncNormalSampleFloat64(r, mean, std, start, end)
        k += 1
    }
    return samples
}

/*
 * Returns a rows x cols matrix of samples from N(mean, std^2) truncated to
 * (start, end), filled in row-major order with truncNormalSampleFloat64.
 */
public func randn_trunc(r: Random, rows: Int64, cols: Int64, mean: Float64, std: Float64,
                        start: Float64, end: Float64): Matrix<Float64> {
    let grid = empty<Float64>(rows, cols)
    for (row in 0..rows) {
        for (col in 0..cols) {
            grid[row, col] = truncNormalSampleFloat64(r, mean, std, start, end)
        }
    }
    return grid
}


/*
 * Unit tests for the normal and truncated normal helpers in this file.
 * Each check compares against a fixed expected value using approxEqual with an
 * absolute tolerance (atol).
 */
@Test
public class TestNormal {
    // Standard normal N(0, 1) density at symmetric points; the expected values are even in x.
    // NOTE(review): expected values look like SciPy's scipy.stats.norm output — confirm source.
    @TestCase
    func testNormalPDF(): Unit {
        @Assert(approxEqual(normalPDF(-2.0, 0.0, 1.0), 0.05399096651318806, atol:1e-13))
        @Assert(approxEqual(normalPDF(-1.0, 0.0, 1.0), 0.24197072451914337, atol:1e-13))
        @Assert(approxEqual(normalPDF(0.0,  0.0, 1.0), 0.3989422804014327,  atol:1e-13))
        @Assert(approxEqual(normalPDF(1.0,  0.0, 1.0), 0.24197072451914337, atol:1e-13))
        @Assert(approxEqual(normalPDF(2.0,  0.0, 1.0), 0.05399096651318806, atol:1e-13))
    }

    // Standard normal N(0, 1) CDF; CDF(-x) and CDF(x) sum to 1 in the expected values.
    @TestCase
    func testNormalCDF(): Unit {
        @Assert(approxEqual(normalCDF(-2.0, 0.0, 1.0), 0.022750131948179195, atol:1e-13))
        @Assert(approxEqual(normalCDF(-1.0, 0.0, 1.0), 0.15865525393145707,  atol:1e-13))
        @Assert(approxEqual(normalCDF(0.0,  0.0, 1.0), 0.5,                  atol:1e-13))
        @Assert(approxEqual(normalCDF(1.0,  0.0, 1.0), 0.8413447460685429,   atol:1e-13))
        @Assert(approxEqual(normalCDF(2.0,  0.0, 1.0), 0.9772498680518208,   atol:1e-13))
    }

    // Standard normal percent point function. The tolerance (1e-6) is much looser than
    // for the PDF/CDF tests, presumably reflecting the accuracy of the foreign ppnd
    // routine — confirm.
    @TestCase
    func testNormalPPF(): Unit {
        @Assert(approxEqual(normalPPF(0.2,  0.0, 1.0), -0.8416212335729142,  atol:1e-6))
        @Assert(approxEqual(normalPPF(0.4,  0.0, 1.0), -0.2533471031357997,  atol:1e-6))
        @Assert(approxEqual(normalPPF(0.6,  0.0, 1.0), 0.2533471031357997,   atol:1e-6))
        @Assert(approxEqual(normalPPF(0.8,  0.0, 1.0), 0.8416212335729143,   atol:1e-6))
    }

    // Test values from Reference [1], page 20
    // Truncated normal with mean 100, std 25, truncated to the interval (50, 150).
    let mean : Float64 = 100.0
    let std : Float64 = 25.0
    let start : Float64 = 50.0
    let end : Float64 = 150.0

    // Truncated-normal density at the five sample points from the reference.
    @TestCase
    func testTruncNormalPDF(): Unit {
        @Assert(approxEqual(truncNormalPDF( 81.630, mean, std, start, end), 0.012763, atol:1e-6))
        @Assert(approxEqual(truncNormalPDF(137.962, mean, std, start, end), 0.005278, atol:1e-6))
        @Assert(approxEqual(truncNormalPDF(122.367, mean, std, start, end), 0.011204, atol:1e-6))
        @Assert(approxEqual(truncNormalPDF(103.704, mean, std, start, end), 0.016536, atol:1e-6))
        @Assert(approxEqual(truncNormalPDF( 94.899, mean, std, start, end), 0.016374, atol:1e-6))
    }

    // Truncated-normal CDF at the same five sample points.
    @TestCase
    func testTruncNormalCDF(): Unit {
        // A little strange that the answer is slightly different from that in the reference
        // NOTE(review): the expected values here differ from Reference [1] in the last
        // printed digit, possibly due to rounding of the tabulated x values — unconfirmed.
        @Assert(approxEqual(truncNormalCDF( 81.630, mean, std, start, end), 0.218419, atol:1e-6))
        @Assert(approxEqual(truncNormalCDF(137.962, mean, std, start, end), 0.956316, atol:1e-6))
        @Assert(approxEqual(truncNormalCDF(122.367, mean, std, start, end), 0.829514, atol:1e-6))
        @Assert(approxEqual(truncNormalCDF(103.704, mean, std, start, end), 0.561699, atol:1e-6))
        @Assert(approxEqual(truncNormalCDF( 94.899, mean, std, start, end), 0.415308, atol:1e-6))
    }

    // Truncated-normal PPF: the probabilities are close to the CDF values above, and the
    // expected outputs are close to the same five sample points (an approximate round trip).
    @TestCase
    func testTruncNormalPPF(): Unit {
        @Assert(approxEqual(truncNormalPPF(0.218418, mean, std, start, end), 81.629951, atol:1e-6))
        @Assert(approxEqual(truncNormalPPF(0.956318, mean, std, start, end), 137.962423, atol:1e-6))
        @Assert(approxEqual(truncNormalPPF(0.829509, mean, std, start, end), 122.366564, atol:1e-6))
        @Assert(approxEqual(truncNormalPPF(0.561695, mean, std, start, end), 103.703754, atol:1e-6))
        @Assert(approxEqual(truncNormalPPF(0.415307, mean, std, start, end), 94.898964, atol:1e-6))
    }
}