/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
 */
package matrix4cj

import std.math.*

/** LU Decomposition.
 * For an m-by-n matrix A with m >= n, the LU decomposition is an m-by-n
 * unit lower triangular matrix L, an n-by-n upper triangular matrix U,
 * and a permutation vector piv of length m so that A(piv,:) = L*U.
 * If m < n, then L is m-by-m and U is m-by-n.
 *
 * The LU decomposition with pivoting always exists, even if the matrix is
 * singular, so the constructor will never fail.  The primary use of the
 * LU decomposition is in the solution of square systems of simultaneous
 * linear equations.  This will fail if isNonsingular() returns false.
 */
public class LUDecomposition {
    /**
     * Packed storage of the decomposition: multipliers of L below the
     * diagonal, U on and above the diagonal. Starts as a clone of A.
     */
    private var LU: Matrix

    /** Row dimension of the decomposed matrix (A.rowNum). */
    private var m: Int64 = 0

    /** Column dimension of the decomposed matrix (A.colNum). */
    private var n: Int64 = 0

    /** Pivot sign: starts at 1 and is negated on every row exchange. */
    private var pivsign: Int64 = 0

    /** Internal storage of the pivot vector (length m). */
    private var piv: Array<Int64>

    /** 
     * LU Decomposition Structure to access L, U and piv.
     * The input matrix is cloned; the decomposition is computed column by
     * column with row exchanges chosen by largest absolute value.
     * @param  A Rectangular matrix
     */
    public init(A: Matrix) {
        LU = A.clone()
        m = A.rowNum
        n = A.colNum
        // Start from the identity permutation.
        piv = Array<Int64>(m) {i => i}
        pivsign = 1
        let LUcolj = Array<Float64>(m) {_ => 0.0}
        // Outer loop over columns.
        for (j in 0..n) {
            // Make a copy of the j-th column to localize references.
            for (i in 0..m) {
                LUcolj[i] = LU[i, j]
            }
            // Apply previous transformations.
            for (i in 0..m) {
                // Most of the time is spent in the following dot product.
                let kmax = min(i, j)
                var s: Float64 = 0.0
                for (k in 0..kmax) {
                    s += LU[i, k] * LUcolj[k]
                }
                LUcolj[i] = LUcolj[i] - s
                LU[i, j] = LUcolj[i]
            }
            // Find pivot: the row at or below j with the largest magnitude.
            // For j >= m the range is empty and no exchange happens.
            var p = j
            for (i in (j + 1)..m) {
                if (abs(LUcolj[i]) > abs(LUcolj[p])) {
                    p = i
                }
            }
            // Exchange rows p and j if necessary, tracking the permutation.
            if (p != j) {
                for (k in 0..n) {
                    let t = LU[p, k]
                    LU[p, k] = LU[j, k]
                    LU[j, k] = t
                }
                let savedPivot: Int64 = piv[p]
                piv[p] = piv[j]
                piv[j] = savedPivot
                pivsign = -pivsign
            }
            // Compute multipliers. A zero pivot is skipped, leaving the
            // column below the diagonal untouched.
            if (j < m && LU[j, j] != 0.0) {
                // LU[j, j] is not modified by the loop (i > j), so read it once.
                let pivot = LU[j, j]
                for (i in (j + 1)..m) {
                    LU[i, j] /= pivot
                }
            }
        }
    }

    /** 
     * Is the matrix nonsingular?
     * Checks the first n diagonal entries of the packed factor for an exact zero.
     * NOTE(review): indexes LU[j, j] for every j < n, so it is meant for
     * m >= n; for m < n the index runs past the last row of LU.
     * @return     true if U, and hence A, is nonsingular.
     */
    public func isNonsingular(): Bool {
        for (j in 0..n) {
            if (LU[j, j] == 0.0) {
                return false
            }
        }
        return true
    }

    /** 
     * Return lower triangular factor
     * The result is unit lower triangular: m-by-n when m >= n and
     * m-by-m when m < n.
     * @return     L
     */
    public func getL(): Matrix {
        let cols = min(m, n)
        let X = Matrix(m, cols, value: 0.0)
        for (i in 0..m) {
            for (j in 0..cols) {
                if (i > j) {
                    X[i, j] = LU[i, j]
                } else if (i == j) {
                    X[i, j] = 1.0
                }
            }
        }
        return X
    }

    /** 
     * Return upper triangular factor
     * The result is n-by-n when m >= n and m-by-n when m < n, so no row
     * beyond the last row of the packed factor is read.
     * @return     U
     */
    public func getU(): Matrix {
        let rows = min(m, n)
        let X = Matrix(rows, n, value: 0.0)
        for (i in 0..rows) {
            for (j in 0..n) {
                if (i <= j) {
                    X[i, j] = LU[i, j]
                }
            }
        }
        return X
    }

    /** 
     * Return pivot permutation vector
     * @return     a copy of piv (length m)
     */
    public func getPivot(): Array<Int64> {
        return Array<Int64>(m) {i => piv[i]}
    }

    /** 
     * Return pivot permutation vector as a one-dimensional double array
     * @return     (double) piv
     */
    public func getDoublePivot(): Array<Float64> {
        return Array<Float64>(m) {i => Float64(piv[i])}
    }

    /** 
     * Determinant
     * Computed as the pivot sign times the product of the diagonal of U.
     * @return     det(A)
     * @exception  IllegalArgumentException  Matrix must be square
     */
    public func det(): Float64 {
        if (m != n) {
            throw IllegalArgumentException("Matrix must be square.")
        }
        var d = Float64(pivsign)
        for (j in 0..n) {
            d *= LU[j, j]
        }
        return d
    }

    /** 
     * Solve A*X = B
     * The substitution loops run over n rows, so this is written for
     * square systems (m == n).
     * @param  B   A Matrix with as many rows as A and any number of columns.
     * @return     X so that L*U*X = B(piv,:)
     * @exception  IllegalArgumentException Matrix row dimensions must agree.
     * @exception  Matrix4cjException  Matrix is singular.
     */
    public func solve(B: Matrix): Matrix {
        if (B.rowNum != m) {
            throw IllegalArgumentException("Matrix row dimensions must agree.")
        }
        if (!this.isNonsingular()) {
            throw Matrix4cjException("Matrix is singular.")
        }
        let nx = B.colNum

        // Right-hand side with rows selected by the pivot vector.
        // NOTE(review): the column-range convention of Matrix's subscript is
        // defined outside this file - confirm 0..nx selects all nx columns.
        let Xmat = B[piv, 0..nx]
        // Forward substitution: solve L*Y = B(piv,:).
        for (k in 0..n) {
            for (i in (k + 1)..n) {
                for (j in 0..nx) {
                    Xmat[i, j] -= Xmat[k, j] * LU[i, k]
                }
            }
        }
        // Back substitution: solve U*X = Y.
        for (k in (n - 1)..=0 : -1) {
            // The diagonal entry is constant for this k; read it once.
            let diag = LU[k, k]
            for (j in 0..nx) {
                Xmat[k, j] /= diag
            }
            for (i in 0..k) {
                for (j in 0..nx) {
                    Xmat[i, j] -= Xmat[k, j] * LU[i, k]
                }
            }
        }
        return Xmat
    }
}