// eaebec39 — created 2024-10-24 (commit history)
package scientific.stats.multivariate

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.continuous.gammaLog
import scientific.stats.random.*


/*
 *  Log probability density function of the Dirichlet distribution.
 *
 *    log f(x; alpha) = log Gamma(sum_i alpha_i) - sum_i log Gamma(alpha_i)
 *                      + sum_i (alpha_i - 1) * log(x_i)
 *
 *  Parameters:
 *    x     - a point on the probability simplex: every entry must lie in [0, 1]
 *            and the entries must sum to 1 within an absolute tolerance of 1e-9.
 *            (An exact `== 1.0` test would reject valid inputs such as
 *            [0.1, 0.2, 0.7], whose floating-point sum is not exactly 1.0.)
 *    alpha - concentration parameters; every entry must be strictly positive.
 *
 *  Returns the log density. For alpha_i == 1 the factor (alpha_i - 1) * log(x_i)
 *  is exactly 0 and is skipped, so x_i == 0 does not produce 0 * (-inf) = NaN.
 *
 *  Throws IllegalArgumentException when the dimensions differ, an alpha entry is
 *  not positive, an x entry lies outside [0, 1], or x does not sum to 1.
 */
public func dirichletLogPDF(x: Vector<Float64>, alpha: Vector<Float64>): Float64 {
    if (x.size() != alpha.size()) {
        throw IllegalArgumentException("dirichletLogPDF: parameters dimensions incompatible.")
    }

    let k = x.size()
    // Absolute tolerance on the simplex constraint sum(x) == 1.
    let tol = 1.0e-9

    var s = 0.0
    for (i in 0..k) {
        // Written as negated comparisons so that NaN entries are rejected too.
        if (!(alpha[i] > 0.0)) {
            throw IllegalArgumentException("dirichletLogPDF: alpha should be positive")
        }
        if (!(x[i] >= 0.0 && x[i] <= 1.0)) {
            throw IllegalArgumentException("dirichletLogPDF: x should lie in [0, 1]")
        }
        s += x[i]
    }

    if (abs(s - 1.0) > tol) {
        throw IllegalArgumentException("dirichletLogPDF: sum of x should be 1.0")
    }

    var logGammaSum = 0.0   // sum_i log Gamma(alpha_i)
    var logKernel = 0.0     // sum_i (alpha_i - 1) * log(x_i)
    for (i in 0..k) {
        logGammaSum += gammaLog(alpha[i])
        // The term vanishes for alpha_i == 1; evaluating it at x_i == 0 would give NaN.
        if (alpha[i] != 1.0) {
            logKernel += (alpha[i] - 1.0) * log(x[i])
        }
    }

    return gammaLog(sum(alpha)) - logGammaSum + logKernel
}


/*
 *  Probability density function of the Dirichlet distribution.
 *
 *  Parameters:
 *    x     - a point on the probability simplex; its entries must sum to 1
 *            within an absolute tolerance of 1e-9 (an exact `== 1.0` test
 *            rejects valid inputs such as [0.1, 0.2, 0.7]).
 *    alpha - concentration parameters, same length as x.
 *
 *  Returns exp(dirichletLogPDF(x, alpha)).
 *
 *  Throws IllegalArgumentException when the dimensions differ or x does not
 *  sum to 1; dirichletLogPDF may throw it for further invalid arguments.
 */
public func dirichletPDF(x: Vector<Float64>, alpha: Vector<Float64>): Float64 {
    if (x.size() != alpha.size()) {
        throw IllegalArgumentException("dirichletPDF: parameters dimensions incompatible.")
    }

    // Absolute tolerance on the simplex constraint sum(x) == 1.
    let tol = 1.0e-9

    var s = 0.0
    for (i in 0..x.size()) {
        s += x[i]
    }

    if (abs(s - 1.0) > tol) {
        throw IllegalArgumentException("dirichletPDF: sum of x should be 1.0")
    }

    return exp(dirichletLogPDF(x, alpha))
}

/*
 * Mean vector of the Dirichlet distribution.
 *
 * E[X_i] = alpha_i / sum_j alpha_j, computed entry by entry into a freshly
 * allocated vector of the same length as alpha.
 */
public func dirichletMean(alpha: Vector<Float64>): Vector<Float64> {
    let total = sum(alpha)
    let dim = alpha.size()
    let mean = empty<Float64>(dim)

    for (idx in 0..dim) {
        mean[idx] = alpha[idx] / total
    }
    return mean
}

/*
 * Marginal variances of the Dirichlet distribution.
 *
 * With a0 = sum_j alpha_j:
 *   Var[X_i] = alpha_i * (a0 - alpha_i) / (a0^2 * (a0 + 1))
 *
 * The formula is evaluated as (alpha_i / a0) * ((a0 - alpha_i) / a0) / (a0 + 1)
 * so that no intermediate grows like a0^3 or a0^2. The direct form
 * a0 * a0 * (a0 + 1) overflows to infinity once a0 exceeds roughly 5.6e102.
 *
 * Returns a new vector of the same length as alpha holding Var[X_i].
 */
public func dirichletVar(alpha: Vector<Float64>): Vector<Float64> {
    let n = alpha.size()
    let res = empty<Float64>(n)
    let s = sum(alpha)
    for (i in 0..n) {
        let p = alpha[i] / s            // mean of component i
        let q = (s - alpha[i]) / s      // 1 - mean of component i
        res[i] = p * q / (s + 1.0)
    }

    return res
}

@Test
public class TestDirichlet {
    // Checks log-pdf, pdf, mean and variance against reference values for
    // x = [0.2, 0.2, 0.6] and alpha = [0.4, 5.0, 15.0].
    @TestCase
    func testDirichlet(): Unit {
        let quantiles = vector([0.2, 0.2, 0.6])
        let alpha = vector([0.4, 5.0, 15.0])

        // Log density
        @Assert(approxEqual(dirichletLogPDF(quantiles, alpha), -1.2574327653159187, atol:1e-6))

        // Density
        @Assert(approxEqual(dirichletPDF(quantiles, alpha), 0.2843831684937255, atol:1e-6))

        // Mean
        let meanTrue = vector([0.01960784, 0.24509804, 0.73529412])
        @Assert(approxEqual(dirichletMean(alpha), meanTrue, atol:1e-6))

        // Variance
        let varTrue = vector([0.00089829, 0.00864603, 0.00909517])
        @Assert(approxEqual(dirichletVar(alpha), varTrue, atol:1e-6))
    }

    // Exercises the IllegalArgumentException paths of the density functions.
    @TestCase
    func testDirichletInvalidArguments(): Unit {
        let alpha = vector([0.4, 5.0, 15.0])

        // x and alpha have different lengths.
        var mismatchThrown = false
        try {
            let _ = dirichletLogPDF(vector([0.5, 0.5]), alpha)
        } catch (e: IllegalArgumentException) {
            mismatchThrown = true
        }
        @Assert(mismatchThrown)

        // x is far from the simplex (its entries sum to 1.3).
        var logSumThrown = false
        try {
            let _ = dirichletLogPDF(vector([0.5, 0.6, 0.2]), alpha)
        } catch (e: IllegalArgumentException) {
            logSumThrown = true
        }
        @Assert(logSumThrown)

        var pdfSumThrown = false
        try {
            let _ = dirichletPDF(vector([0.5, 0.6, 0.2]), alpha)
        } catch (e: IllegalArgumentException) {
            pdfSumThrown = true
        }
        @Assert(pdfSumThrown)
    }
}