package scientific.stats.continuous

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.stats.random.*

/*
 * Logistic sigmoid (inverse of logit): 1 / (1 + exp(-x)).
 * The negation is written as 0 - x in T's own arithmetic.
 */
public func expit<T>(x: T): T where T <: Float<T> {
    let one = T.fromFloat(1.0)
    let negX = T.fromFloat(0.0) - x
    let denominator = one + exp(negX)
    return one / denominator
}

/*
 * Log-odds of x: log(x / (1 - x)). Inverse of expit.
 * No domain check is done; inputs outside (0, 1) follow whatever `log`
 * returns for a non-positive argument.
 */
public func logit<T>(x: T): T where T <: Float<T> {
    let complement = T.fromFloat(1.0) - x
    let odds = x / complement
    return log(odds)
}

/*
 * Log of the probability density function of the logistic distribution.
 *
 * With y = (x - loc) / scale:
 *     log f(x) = -y - 2 * log(1 + exp(-y)) - log(scale)
 * The expression is symmetric in y, so it is evaluated at |y|. Then
 * exp(-|y|) <= 1 and cannot overflow. Evaluating exp(-y) directly
 * overflows to +inf for very negative y and would yield -inf.
 *
 * x     : point at which the density is evaluated
 * loc   : location parameter (default 0.0)
 * scale : scale parameter (default 1.0); presumably expected to be > 0,
 *         which is not checked here (log of a non-positive scale).
 */
public func logisticLogPDF(x: Float64, loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    let y = abs((x - loc) / scale)

    let res = -y - 2.0 * log(1.0 + exp(-y))
    return res - log(scale)
}

/*
 * Probability density function of the logistic distribution.
 *
 * Computed as exp(logisticLogPDF(x, loc, scale)), so it shares the
 * log-density's formula and handling of the standardized value.
 *
 * x     : point at which the density is evaluated
 * loc   : location parameter (default 0.0)
 * scale : scale parameter (default 1.0)
 */
public func logisticPDF(x: Float64, loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    let logDensity = logisticLogPDF(x, loc: loc, scale: scale)
    return exp(logDensity)
}


/*
 * Cumulative distribution function of the logistic distribution:
 *     F(x) = expit((x - loc) / scale)
 *
 * x     : point at which the CDF is evaluated
 * loc   : location parameter (default 0.0)
 * scale : scale parameter (default 1.0)
 */
public func logisticCDF(x: Float64, loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    return expit((x - loc) / scale)
}


/*
 * Log of the cumulative distribution function of the logistic distribution.
 *
 * With y = (x - loc) / scale:
 *     log F(x) = log(expit(y)) = -log(1 + exp(-y))
 * It is evaluated in a sign-dependent form so that the argument of exp is
 * never positive:
 *     y >= 0 :  -log(1 + exp(-y))
 *     y <  0 :   y - log(1 + exp(y))
 * No threshold is applied. The log-CDF stays finite and accurate in the left
 * tail, where the CDF itself is tiny.
 *
 * x     : point at which the log-CDF is evaluated
 * loc   : location parameter (default 0.0)
 * scale : scale parameter (default 1.0)
 */
public func logisticLogCDF(x: Float64, loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    let y = (x - loc) / scale

    if (y >= 0.0) {
        return -log(1.0 + exp(-y))
    }
    // For negative y: log(1 / (1 + e^-y)) = y - log(1 + e^y), with e^y < 1.
    return y - log(1.0 + exp(y))
}


/*
 * Percent-point function (inverse CDF) of the logistic distribution:
 *     PPF(q) = logit(q) * scale + loc
 * Throws IllegalArgumentException when q <= 0.0 or q >= 1.0.
 *
 * q     : probability level, expected strictly inside (0, 1)
 * loc   : location parameter (default 0.0)
 * scale : scale parameter (default 1.0)
 */
public func logisticPPF(q: Float64, loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    let outOfRange = q <= 0.0 || q >= 1.0
    if (outOfRange) {
        throw IllegalArgumentException("logisticPPF: quantile out of bound.")
    }

    let standardQuantile = logit(q)
    return standardQuantile * scale + loc
}


/*
 * Mean of the logistic distribution, which equals its location parameter.
 * `scale` is accepted for a uniform signature but does not affect the result.
 */
public func logisticMean(loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    let mean = loc
    return mean
}


/*
 * Variance of the logistic distribution: pi^2 / 3 * scale^2.
 * `loc` is accepted for a uniform signature but does not affect the result.
 */
public func logisticVar(loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    let pi = Float64.getPI()
    let unitVariance = pi * pi / 3.0
    return unitVariance * scale * scale
}

/*
 * Standard deviation of the logistic distribution:
 *     sqrt(pi^2 / 3 * scale^2)
 * The variance is computed here with the same expression as logisticVar.
 * The half-normal variance must not be used.
 *
 * Throws IllegalArgumentException when the variance is below 1e-6,
 * i.e. when |scale| is below roughly 5.5e-4.
 *
 * loc   : location parameter (does not affect the result)
 * scale : scale parameter (default 1.0)
 */
public func logisticStd(loc!: Float64 = 0.0, scale!: Float64 = 1.0): Float64 {
    let pi = Float64.getPI()
    let variance = pi * pi / 3.0 * scale * scale

    if (variance < 0.000001) {
        throw IllegalArgumentException("logisticStd: return-value too small.")
    }

    return sqrt(variance)
}

@Test
public class TestLogistic {
    @TestCase
    func testLogistic(): Unit {
        // All checks evaluate at x = 3.0 (or q = 0.7) with loc = 2.0, scale = 1.0.
        let tol = 1e-13
        let expectedLogPdf = -1.6265233750364456
        let expectedLogCdf = -0.31326168751822286
        let expectedPpf = 2.8472978603872034

        @Assert(approxEqual(logisticLogPDF(3.0, loc: 2.0, scale: 1.0), expectedLogPdf, atol: tol))
        @Assert(approxEqual(logisticLogCDF(3.0, loc: 2.0, scale: 1.0), expectedLogCdf, atol: tol))
        @Assert(approxEqual(logisticPPF(0.7, loc: 2.0, scale: 1.0), expectedPpf, atol: tol))
    }
}