package scientific.stats.multivariate
import std.math.*
import std.unittest.*
import std.unittest.testmacro.*
import scientific.numbers.*
import scientific.linear.*
import scientific.stats.continuous.gammaLog
import scientific.stats.random.*
/*
 * Log probability density function of the Dirichlet distribution.
 *
 *   log f(x; alpha) = gammaLog(sum_i alpha_i) - sum_i gammaLog(alpha_i)
 *                     + sum_i (alpha_i - 1) * log(x_i)
 *
 * @param x      point at which to evaluate the density; its entries must sum
 *               to 1.0 up to a floating-point tolerance of 1e-9.
 * @param alpha  concentration parameters; must have the same length as x.
 * @return the log density at x.
 * @throws IllegalArgumentException if x and alpha differ in length, or if the
 *         entries of x do not sum to 1.0 within tolerance (or the sum is NaN).
 */
public func dirichletLogPDF(x: Vector<Float64>, alpha: Vector<Float64>): Float64 {
    if (x.size() != alpha.size()) {
        throw IllegalArgumentException("dirichletLogPDF: parameters dimensions incompatible.")
    }
    let k = x.size()
    var s = 0.0
    for (i in 0..k) {
        s += x[i]
    }
    // An exact `s != 1.0` test rejects valid simplex points whose floating-point
    // sum is off by rounding (e.g. ten entries of 0.1 sum to 0.9999999999999999).
    // The tolerance matches SciPy's Dirichlet input check (10e-10). The negated
    // `<=` form also rejects a NaN sum, as the original exact comparison did.
    if (!(abs(s - 1.0) <= 1.0e-9)) {
        throw IllegalArgumentException("dirichletLogPDF: sum of x should be 1.0")
    }
    var sumLogGammaAlpha = 0.0
    var sumWeightedLogX = 0.0
    for (i in 0..k) {
        sumLogGammaAlpha += gammaLog(alpha[i])
        sumWeightedLogX += (alpha[i] - 1.0) * log(x[i])
    }
    return gammaLog(sum(alpha)) - sumLogGammaAlpha + sumWeightedLogX
}
/*
 * Probability density function of the Dirichlet distribution.
 *
 * Validates its inputs, then returns exp(dirichletLogPDF(x, alpha)).
 *
 * @param x      point at which to evaluate the density; its entries must sum
 *               to 1.0 up to a floating-point tolerance of 1e-9.
 * @param alpha  concentration parameters; must have the same length as x.
 * @return the density at x.
 * @throws IllegalArgumentException if x and alpha differ in length, or if the
 *         entries of x do not sum to 1.0 within tolerance (or the sum is NaN).
 */
public func dirichletPDF(x: Vector<Float64>, alpha: Vector<Float64>): Float64 {
    if (x.size() != alpha.size()) {
        throw IllegalArgumentException("dirichletPDF: parameters dimensions incompatible.")
    }
    var s = 0.0
    for (i in 0..x.size()) {
        s += x[i]
    }
    // Tolerance instead of exact equality: floating-point sums of valid simplex
    // points are frequently off by one ulp. The negated `<=` also rejects NaN.
    if (!(abs(s - 1.0) <= 1.0e-9)) {
        throw IllegalArgumentException("dirichletPDF: sum of x should be 1.0")
    }
    return exp(dirichletLogPDF(x, alpha))
}
/*
 * Mean vector of the Dirichlet distribution.
 *
 * Component j of the result is alpha[j] / sum(alpha).
 *
 * @param alpha  concentration parameters.
 * @return a new vector of the same length as alpha holding the means.
 */
public func dirichletMean(alpha: Vector<Float64>): Vector<Float64> {
    let total = sum(alpha)
    let dim = alpha.size()
    let mean = empty<Float64>(dim)
    var j = 0
    while (j < dim) {
        mean[j] = alpha[j] / total
        j += 1
    }
    return mean
}
/*
 * Per-component variance of the Dirichlet distribution.
 *
 * With s = sum(alpha):
 *   Var[X_i] = alpha_i * (s - alpha_i) / (s^2 * (s + 1))
 *
 * @param alpha  concentration parameters.
 * @return a new vector of the same length as alpha holding the variances.
 */
public func dirichletVar(alpha: Vector<Float64>): Vector<Float64> {
    let n = alpha.size()
    let res = empty<Float64>(n)
    let s = sum(alpha)
    let sPlusOne = s + 1.0
    for (i in 0..n) {
        // Evaluated as (alpha_i / s) * ((s - alpha_i) / s) / (s + 1) rather than
        // dividing by s * s * (s + 1): that cubic denominator overflows to +inf
        // for s above ~5.6e102 (yielding 0) and underflows to 0 for tiny s
        // (yielding inf/NaN), while each factor here stays in a moderate range.
        res[i] = (alpha[i] / s) * ((s - alpha[i]) / s) / sPlusOne
    }
    return res
}
/*
 * Unit tests for the Dirichlet helpers, checked against reference values for
 * alpha = [0.4, 5.0, 15.0] and x = [0.2, 0.2, 0.6] with absolute tolerance 1e-6.
 */
@Test
public class TestDirichlet {
    @TestCase
    func testDirichlet(): Unit {
        let x = vector([0.2, 0.2, 0.6])
        let alpha = vector([0.4, 5.0, 15.0])
        let tol = 1e-6

        // Density and log-density at x.
        let expectedLogPdf = -1.2574327653159187
        let expectedPdf = 0.2843831684937255
        @Assert(approxEqual(dirichletLogPDF(x, alpha), expectedLogPdf, atol: tol))
        @Assert(approxEqual(dirichletPDF(x, alpha), expectedPdf, atol: tol))

        // First and second moments.
        let expectedMean = vector([0.01960784, 0.24509804, 0.73529412])
        @Assert(approxEqual(dirichletMean(alpha), expectedMean, atol: tol))
        let expectedVar = vector([0.00089829, 0.00864603, 0.00909517])
        @Assert(approxEqual(dirichletVar(alpha), expectedVar, atol: tol))
    }
}