// Commit eaebec39, created on 2024-10-24 (see commit history).
package scientific.stats.multivariate

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.random.*

// TODO: case where X is a matrix, n is a vector.

/*
 *  Log probability mass function of the multinomial distribution.
 *
 *  Returns log P(X = x) = log(n!) - sum_i log(x_i!) + sum_i x_i * log(p_i).
 *
 *  Parameters:
 *    x - observed counts per category; every entry must be non-negative.
 *    n - number of trials; must equal the sum of the entries of x.
 *    p - category probabilities; every entry must lie in [0, 1] and the
 *        entries must sum to 1.0 within a small floating-point tolerance.
 *
 *  A category with x_i = 0 contributes nothing (even when p_i = 0);
 *  a category with x_i > 0 and p_i = 0 makes the result -infinity.
 *
 *  Throws IllegalArgumentException if x and p have different sizes,
 *  an entry of x is negative, an entry of p lies outside [0, 1] (or is NaN),
 *  the sum of x differs from n, or the sum of p differs from 1.0.
 */
public func multinomialLogPMF(x: Vector<Int64>, n: Int64, p: Vector<Float64>): Float64 {
    if (x.size() != p.size()) {
        throw IllegalArgumentException("multinomialLogPMF: parameters dimensions incompatible.")
    }

    // Summing floating-point probabilities rarely lands exactly on 1.0
    // (e.g. ten entries of 0.1 sum to 0.9999999999999999), so allow a small slack.
    let sumTolerance = 1e-10
    var sum_x = 0
    var sum_p = 0.0
    let k = x.size()
    for (i in 0..k) {
        if (x[i] < 0) {
            throw IllegalArgumentException("multinomialLogPMF: x must be non-negative.")
        }
        // Negated range test so that NaN entries are rejected as well.
        if (!(p[i] >= 0.0 && p[i] <= 1.0)) {
            throw IllegalArgumentException("multinomialLogPMF: p must lie in [0, 1].")
        }
        sum_x += x[i]
        sum_p += p[i]
    }

    if (sum_x != n) {
        throw IllegalArgumentException("multinomialLogPMF: the sum of x incompatible with n.")
    }
    if (sum_p > 1.0 + sumTolerance || sum_p < 1.0 - sumTolerance) {
        throw IllegalArgumentException("multinomialLogPMF: the sum of p not equal to 1.0.")
    }

    // log of the multinomial coefficient n! / (x_1! ... x_k!)
    var logCoef = supMultinomial(n)
    // sum of x_i * log(p_i)
    var logProb = 0.0
    for (i in 0..k) {
        // Empty categories contribute p_i^0 = 1 and log(0!) = 0; skipping them
        // avoids the NaN produced by 0 * log(0.0) when p_i == 0.
        if (x[i] == 0) {
            continue
        }
        logCoef -= supMultinomial(x[i])
        logProb += Float64(x[i]) * log(p[i])
    }

    return logCoef + logProb
}


/*
 *  Probability mass function of the multinomial distribution.
 *
 *  Returns P(X = x) = n! / (x_1! ... x_k!) * p_1^x_1 * ... * p_k^x_k.
 *  The value is accumulated in log space (sums of log factorials and
 *  x_i * log(p_i)) and exponentiated once at the end, so large factorials
 *  are never formed directly.
 *
 *  Parameters:
 *    x - observed counts per category; every entry must be non-negative.
 *    n - number of trials; must equal the sum of the entries of x.
 *    p - category probabilities; every entry must lie in [0, 1] and the
 *        entries must sum to 1.0 within a small floating-point tolerance.
 *
 *  A category with x_i = 0 contributes a factor of 1 (even when p_i = 0);
 *  a category with x_i > 0 and p_i = 0 makes the result 0.0.
 *
 *  Throws IllegalArgumentException if x and p have different sizes,
 *  an entry of x is negative, an entry of p lies outside [0, 1] (or is NaN),
 *  the sum of x differs from n, or the sum of p differs from 1.0.
 */
public func multinomialPMF(x: Vector<Int64>, n: Int64, p: Vector<Float64>): Float64 {
    if (x.size() != p.size()) {
        throw IllegalArgumentException("multinomialPMF: parameters dimensions incompatible.")
    }

    // Summing floating-point probabilities rarely lands exactly on 1.0
    // (e.g. ten entries of 0.1 sum to 0.9999999999999999), so allow a small slack.
    let sumTolerance = 1e-10
    var sum_x = 0
    var sum_p = 0.0
    let k = x.size()
    for (i in 0..k) {
        if (x[i] < 0) {
            throw IllegalArgumentException("multinomialPMF: x must be non-negative.")
        }
        // Negated range test so that NaN entries are rejected as well.
        if (!(p[i] >= 0.0 && p[i] <= 1.0)) {
            throw IllegalArgumentException("multinomialPMF: p must lie in [0, 1].")
        }
        sum_x += x[i]
        sum_p += p[i]
    }

    if (sum_x != n) {
        throw IllegalArgumentException("multinomialPMF: the sum of x incompatible with n.")
    }
    if (sum_p > 1.0 + sumTolerance || sum_p < 1.0 - sumTolerance) {
        throw IllegalArgumentException("multinomialPMF: the sum of p not equal to 1.0.")
    }

    // Accumulate the log-probability, starting from log(n!).
    var logValue = supMultinomial(n)
    for (i in 0..k) {
        // Empty categories contribute a factor of exactly 1; skipping them
        // avoids the NaN produced by 0 * log(0.0) when p_i == 0.
        if (x[i] == 0) {
            continue
        }
        logValue += Float64(x[i]) * log(p[i]) - supMultinomial(x[i])
    }

    // Convert back from log space: this function returns a probability, not its log.
    return exp(logValue)
}

/*
 *  Natural logarithm of n factorial, accumulated term by term as
 *  log(1) + log(2) + ... + log(n).
 *  Returns 0.0 whenever n <= 1 (the sum is empty or only contains log(1) = 0).
 */
func supMultinomial(n: Int64): Float64 {
    var total = 0.0
    var term = 1
    while (term <= n) {
        total += log(Float64(term))
        term += 1
    }
    return total
}


@Test
public class TestMultinomial {
    @TestCase
    func testMultinomial(): Unit {
        // Eight trials split as (1, 3, 4) across three categories.
        let counts = vector<Int64>([1, 3, 4])
        let probs = vector<Float64>([0.3, 0.2, 0.5])
        // Reference value: log(280 * 0.3^1 * 0.2^3 * 0.5^4) = log(0.042).
        let expected = -3.17008566
        let actual = multinomialLogPMF(counts, 8, probs)
        @Assert(approxEqual(actual, expected, atol: 1e-6))
    }
}