eaebec39创建于 2024年10月24日历史提交
package scientific.stats.multivariate

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.random.*

/*
 *  Probability density function of the matrix normal distribution MN(mean, rowcov, colcov).
 *
 *  x      : the observation (n x p).
 *  mean   : the location matrix; must have the same shape as x.
 *  rowcov : the n x n among-row covariance (U); its row count must equal x's row count.
 *  colcov : the p x p among-column covariance (V); its row count must equal x's column count.
 *
 *  The density is evaluated through the equivalent multivariate normal:
 *  X ~ MN(M, U, V)  <=>  vec(X) ~ N(vec(M), V kron U).
 *  NOTE(review): assumes flatten() concatenates rows, so flatten(x.transpose())
 *  is the column-stacked vec(x) — TODO confirm against scientific.linear.
 *
 *  Throws IllegalArgumentException when the dimensions are incompatible or either
 *  covariance matrix is not square.
 */
public func matrixNormalPDF(x: Matrix<Float64>, mean: Matrix<Float64>, rowcov: Matrix<Float64>, colcov: Matrix<Float64>): Float64 {
    let dimensionsMismatch = (x.shape() != mean.shape() || x.getRows() != rowcov.getRows() ||
        x.getCols() != colcov.getRows() || !rowcov.isSquare() || !colcov.isSquare())
    if (dimensionsMismatch) {
        throw IllegalArgumentException("matrixNormalPDF: parameters dimensions incompatible.")
    }

    // The covariance of vec(X) is colcov kron rowcov (column covariance on the outside).
    return multiNormalPDF(flatten(x.transpose()), flatten(mean.transpose()), kron(colcov, rowcov))
}

/*
 *  Log probability density function of the matrix normal distribution MN(mean, rowcov, colcov).
 *
 *  Takes the same parameters and dimension rules as matrixNormalPDF. Throws
 *  IllegalArgumentException when the dimensions are incompatible or either covariance
 *  matrix is not square.
 *
 *  NOTE(review): this is computed as log(pdf). When the density underflows to 0.0
 *  (far-off observations or large matrices), the result degenerates to log(0.0),
 *  presumably -inf. A direct log-space evaluation would be more robust.
 */
public func matrixNormalLogPDF(x: Matrix<Float64>, mean: Matrix<Float64>, rowcov: Matrix<Float64>, colcov: Matrix<Float64>): Float64 {
    // Validate here as well, so that the error message names this function rather than matrixNormalPDF.
    let dimensionsMismatch = (x.shape() != mean.shape() || x.getRows() != rowcov.getRows() ||
        x.getCols() != colcov.getRows() || !rowcov.isSquare() || !colcov.isSquare())
    if (dimensionsMismatch) {
        throw IllegalArgumentException("matrixNormalLogPDF: parameters dimensions incompatible.")
    }

    let density = matrixNormalPDF(x, mean, rowcov, colcov)
    return log(density)
}


/*
 *  Unit tests for the matrix normal density functions.
 *
 *  Shared fixture: X - M is 0.1 everywhere (3 x 2), U = diag(1, 2, 3), V = diag(0.3, 0.3).
 */
@Test
public class TestMatrixNormal {
    @TestCase
    func testMatrixNormal(): Unit {
        let x = matrix([[0.1, 1.1], [2.1, 3.1], [4.1, 5.1]])
        let mean = matrix([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
        let u = diag(vector([1.0, 2.0, 3.0]))
        let v = diag(vector([0.3, 0.3]))
        let res = matrixNormalPDF(x, mean, u, v)
        @Assert(approxEqual(res, 0.023410202050005054, atol:1e-6))
    }

    @TestCase
    func testMatrixNormalLogPDF(): Unit {
        let x = matrix([[0.1, 1.1], [2.1, 3.1], [4.1, 5.1]])
        let mean = matrix([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
        let u = diag(vector([1.0, 2.0, 3.0]))
        let v = diag(vector([0.3, 0.3]))
        let res = matrixNormalLogPDF(x, mean, u, v)
        // Closed form with n = 3, p = 2:
        // -(np/2)ln(2*pi) - (p/2)ln|U| - (n/2)ln|V| - tr(V^-1 (X-M)^T U^-1 (X-M)) / 2
        // = -5.5136311992 - 1.7917594692 + 3.6119184049 - 0.0611111111 = -3.7545833746 (approx.)
        // This also equals ln(0.023410202050005054).
        @Assert(approxEqual(res, -3.75458337, atol:1e-6))
    }

    @TestCase
    func testMatrixNormalDimensionMismatch(): Unit {
        let x = matrix([[0.1, 1.1], [2.1, 3.1], [4.1, 5.1]])
        let mean = matrix([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
        let u = diag(vector([1.0, 2.0, 3.0]))
        let v = diag(vector([0.3, 0.3]))
        // The row covariance is 2 x 2, but x has 3 rows.
        @ExpectThrows[IllegalArgumentException](matrixNormalPDF(x, mean, diag(vector([1.0, 2.0])), v))
        // The column covariance is 3 x 3, but x has 2 columns.
        @ExpectThrows[IllegalArgumentException](matrixNormalLogPDF(x, mean, u, diag(vector([0.3, 0.3, 0.3]))))
    }
}