// Commit eaebec39, created on 2024-10-24 (commit history)
package scientific.stats.multivariate

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.random.*
import scientific.stats.continuous.multiGammaLog

/*
 *  Log of the probability density function, scalar case (dimension p = 1),
 *  evaluated element-wise over a vector of points.
 *
 *  With d = Float64(df), the value stored for each element x[i] is
 *      d/2 * log(cov) - d/2 * log(2) - multiGammaLog(d/2, 1)
 *          - (d + 2)/2 * log(x[i]) - cov / x[i] / 2
 *
 *  x   - points at which the log-density is evaluated
 *  df  - number of degrees of freedom
 *  cov - scalar counterpart of the matrix overload's cov; must be positive
 *
 *  Returns a vector with x.size() elements holding the log-densities.
 *  Throws IllegalArgumentException when cov <= 0.0, mirroring the matrix
 *  overload, which rejects a cov whose determinant is not positive.
 */
public func invwishartLogPDF(x: Vector<Float64>, df: Int64, cov: Float64): Vector<Float64> {
    // Without this check log(cov) would silently turn every element into NaN/-inf.
    if (cov <= 0.0) {
        throw IllegalArgumentException("invwishartLogPDF: cov is not a positive number.")
    }

    let n = x.size()
    let res = empty<Float64>(n)
    let d = Float64(df)
    let p = 1.0

    // Terms that do not depend on x[i]: computed once, outside the loop.
    let res1 = d * p * 0.5 * log(2.0)
    let res2 = multiGammaLog(d * 0.5, 1)
    let res3 = d * 0.5 * log(cov)
    let logNorm = - res1 - res2 + res3
    let exponent = (d + p + 1.0) * 0.5

    for (i in 0..n) {
        res[i] = - exponent * log(x[i]) - cov / x[i] * 0.5 + logNorm
    }

    return res
}

/*
 *  Probability density function, scalar case, evaluated element-wise.
 *  Each output element is exp() of the matching element returned by
 *  invwishartLogPDF(x, df, cov).
 *
 *  x   - points at which the density is evaluated
 *  df  - number of degrees of freedom
 *  cov - scalar parameter forwarded unchanged to invwishartLogPDF
 *
 *  Returns a newly allocated vector with x.size() elements.
 */
public func invwishartPDF(x: Vector<Float64>, df: Int64, cov: Float64): Vector<Float64> {
    let count = x.size()
    let density = empty<Float64>(count)

    // Work in log space first, then exponentiate each entry.
    let logDensity = invwishartLogPDF(x, df, cov)
    for (idx in 0..count) {
        let logValue = logDensity[idx]
        density[idx] = exp(logValue)
    }

    return density
}

/*
 *  Log of the probability density function for a matrix argument.
 *
 *  With p = x.getRows() and d = Float64(df), the returned value is
 *      d/2 * log(det(cov)) - d*p/2 * log(2) - multiGammaLog(d/2, p)
 *          - (d + p + 1)/2 * log(det(x)) - trace(cov * inv(x)) / 2
 *
 *  x   - square matrix at which the log-density is evaluated
 *  df  - number of degrees of freedom
 *  cov - square matrix of the same size as x with a positive determinant
 *
 *  Throws IllegalArgumentException when cov is not square, when
 *  det(cov) <= 0.0, or when x is not a square matrix of the same size as cov.
 *  NOTE: only the shape and the determinant of cov are checked; symmetry and
 *  positive definiteness are not verified here.
 */
public func invwishartLogPDF(x: Matrix<Float64>, df: Int64, cov: Matrix<Float64>): Float64 {
    if (cov.getRows() != cov.getCols()) {
        throw IllegalArgumentException("invwishartLogPDF: cov is not a positive symmetric matrix.")
    }

    // Computed once and reused for both the validity check and the formula.
    let covDet = det(cov)
    if (covDet <= 0.0) {
        throw IllegalArgumentException("invwishartLogPDF: cov is not a positive symmetric matrix.")
    }

    // p is taken from x, so x has to be square and conformable with cov.
    let n = x.getRows()
    if (n != x.getCols() || n != cov.getRows()) {
        throw IllegalArgumentException("invwishartLogPDF: x must be a square matrix of the same size as cov.")
    }

    let d = Float64(df)
    let p = Float64(n)
    let xDet = det(x)

    // Terms that do not depend on x.
    let res1 = d * p * 0.5 * log(2.0)
    let res2 = multiGammaLog(d * 0.5, Int64(p))
    let res3 = d * 0.5 * log(covDet)
    let logNorm = - res1 - res2 + res3

    return - (d + p + 1.0) * 0.5 * log(xDet) - trace(cov * inv(x)) * 0.5 + logNorm
}

/*
 *  Probability density function for a matrix argument.
 *  Computed as exp() of invwishartLogPDF(x, df, cov); all three arguments
 *  are forwarded unchanged, so any exception raised there propagates.
 *
 *  x   - matrix at which the density is evaluated
 *  df  - number of degrees of freedom
 *  cov - matrix parameter passed through to invwishartLogPDF
 */
public func invwishartPDF(x: Matrix<Float64>, df: Int64, cov: Matrix<Float64>): Float64 {
    return exp(invwishartLogPDF(x, df, cov))
}

/*
 *  Unit tests for the log-density overloads of invwishartLogPDF.
 *  Expected values are hard-coded reference numbers compared with an
 *  absolute tolerance of 1e-6.
 */
@Test
public class TestInvwishart {
    @TestCase
    func testInvwishart(): Unit {
        // Scalar overload: four evaluation points, df = 6, cov = 1.0.
        let points = vector([0.24197072, 0.2186801, 0.17771369, 0.1375705])
        let expectedLog = vector([0.83679977, 1.02154777, 1.32422305, 1.52738621])
        let actualLog = invwishartLogPDF(points, 6, 1.0)
        @Assert(approxEqual(actualLog, expectedLog, atol:1e-6))

        // Matrix overload: 2x2 argument and 2x2 cov, df = 6.
        let sample = matrix([[2.0, 1.0], [1.0, 2.0]])
        let scale = matrix([[3.0, 2.0], [2.0, 3.0]])
        let matrixLog = invwishartLogPDF(sample, 6, scale)
        @Assert(approxEqual(matrixLog, -7.157852972354762, atol:1e-6))
    }
}