package scientific.linear

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*

/* Compute the singular value decomposition (SVD) of a real M-by-N
   matrix A: A = U * diag(S) * transpose(V).

   Raw binding to the LAPACKE C interface. Every matrix/vector argument is
   an untyped pointer to a caller-owned Float64 buffer. LAPACK overwrites
   the contents of A, which is why callers in this file pass a copy.

   Returns LAPACKE's info code: 0 on success, < 0 if an argument was
   rejected (or an internal LAPACKE error occurred), > 0 if the SVD
   iteration did not converge.

   NOTE(review): the integer arguments and the return value are declared as
   IntNative/Int64. That matches a LAPACKE build whose lapack_int is 64-bit
   (ILP64). A standard LP64 build uses a 32-bit C int; confirm which library
   this links against.
*/
foreign func LAPACKE_dgesvd(
    matrix_layout: IntNative,   // storage order: LAPACK_COL_MAJOR (102) is used in this file
    jobu: UInt8,           // which columns of U to compute: 'a' = all m, 's' = first min(m,n)
    jobvt: UInt8,          // which rows of transpose(V) to compute: 'a' = all n, 's' = first min(m,n)
    m: IntNative,          // number of rows in A
    n: IntNative,          // number of columns in A
    a: CPointer<Unit>,     // in/out: the m-by-n matrix (Float64); destroyed on exit
    lda: IntNative,        // leading dimension of A (LAPACK requires >= max(1, m) in column-major)
    s: CPointer<Unit>,     // output - min(m,n) singular values (Float64)
    u: CPointer<Unit>,     // output - matrix U (Float64)
    ldu: IntNative,        // leading dimension of U
    vt: CPointer<Unit>,    // output - matrix transpose(V) (Float64)
    ldvt: IntNative,       // leading dimension of transpose(V)
    superb: CPointer<Unit>):   // output workspace (Float64), at least min(m,n)-1 entries
    Int64

/* High-level function. Given an m-by-n input matrix A, return the tuple
   (U, S, transpose(V)) with A = U * diag(S) * transpose(V).

   - full_matrices = true (default): U is m-by-m and transpose(V) is n-by-n.
   - full_matrices = false: U is m-by-min(m,n) and transpose(V) is
     min(m,n)-by-n (the "thin" SVD).
   S always has min(m,n) entries (ordered as returned by LAPACK dgesvd).

   The input matrix is not modified; LAPACK works on a copy.

   Throws IllegalArgumentException if LAPACKE reports a negative info
   (rejected argument, e.g. NaN input, or internal LAPACKE error), and
   Exception if the SVD iteration did not converge (positive info).
*/
public func svd(a: Matrix<Float64>, full_matrices!: Bool = true): (Matrix<Float64>, Vector<Float64>, Matrix<Float64>) {
    let LAPACK_COL_MAJOR: Int64 = 102
    // 'a' = compute all columns of U / rows of transpose(V);
    // 's' = compute only the first min(m, n) of them.
    let job: UInt32 = if (full_matrices) { UInt32(r'a') } else { UInt32(r's') }

    // dgesvd destroys its input matrix, so hand it a copy.
    let acopy = a.copy()
    let m = acopy.getRows()
    let n = acopy.getCols()
    let minmn = min(m, n)

    // Output shapes depend on whether the full or thin factors are requested.
    let ucols = if (full_matrices) { m } else { minmn }
    let vtrows = if (full_matrices) { n } else { minmn }
    let s = zeros<Float64>(minmn)
    let u = zeros<Float64>(m, ucols)
    let vt = zeros<Float64>(vtrows, n)
    let superb = zeros<Float64>(minmn)

    // LAPACK rejects leading dimensions < 1 even when a dimension is zero.
    let lda = max(1, m)
    let ldu = max(1, m)
    let ldvt = max(1, vtrows)

    let info = unsafe {
        LAPACKE_dgesvd(IntNative(LAPACK_COL_MAJOR), UInt8(job), UInt8(job), IntNative(m),
            IntNative(n), acopy.ptr, IntNative(lda), s.ptr, u.ptr, IntNative(ldu), vt.ptr,
            IntNative(ldvt), superb.ptr)
    }
    // A non-zero info means the output buffers do not hold a valid SVD.
    if (info < 0) {
        throw IllegalArgumentException("svd: LAPACKE_dgesvd rejected an argument or failed internally (info = ${info})")
    }
    if (info > 0) {
        throw Exception("svd: LAPACKE_dgesvd did not converge (info = ${info})")
    }
    return (u, s, vt)
}

/* Return only the singular values of `a` (min(m, n) entries, ordered as
   returned by LAPACK dgesvd).

   The thin SVD is requested because U and transpose(V) are discarded
   anyway. This avoids building the full m-by-m and n-by-n factors just to
   throw them away; the singular values are the same either way.
*/
public func svdvals(a: Matrix<Float64>): Vector<Float64> {
    return svd(a, full_matrices: false)[1]
}

/* Build the m-by-n "Sigma" factor of an SVD from the singular values `s`.
   `s` is placed on the main diagonal and all other entries are zero.
   This is a thin wrapper over the project's `diag(s, m, n)`.
*/
public func diagsvd(s: Vector<Float64>, m: Int64, n: Int64): Matrix<Float64> {
    let sigma = diag(s, m, n)
    return sigma
}

/* Return an orthonormal basis for the column space (range) of `a`.

   The basis consists of the leading left singular vectors of `a` whose
   singular values exceed a cutoff:
     rcond == 0.0 (default): tol = max(s) * max(m, n) * eps
     otherwise:              tol = max(s) * rcond
   The result has a.getRows() rows and one column per singular value above
   tol, i.e. the numerical rank of `a`.
*/
public func orth(a: Matrix<Float64>, rcond!: Float64 = 0.0): Matrix<Float64> {
    let rows = a.getRows()
    let cols = a.getCols()

    let decomposition = svd(a)
    let leftVecs = decomposition[0]
    let sv = decomposition[1]

    // Same evaluation order as the reference formula, to keep rounding identical.
    let largest = max(sv)
    let cutoff = if (rcond == 0.0) {
        largest * Float64(max(rows, cols)) * Float64.eps()
    } else {
        largest * rcond
    }

    // Numerical rank: how many singular values lie above the cutoff.
    var rank: Int64 = 0
    for (k in 0..sv.size() where sv[k] > cutoff) {
        rank++
    }

    // Keep the first `rank` columns of U.
    let height = leftVecs.getRows()
    let basis = zeros<Float64>(height, rank)
    for (col in 0..rank) {
        for (row in 0..height) {
            basis[row, col] = leftVecs[row, col]
        }
    }
    return basis
}

/* Return an orthonormal basis for the null space of `a`, one vector per column.

   The basis vectors are the rows of transpose(V) (from the full SVD) that
   correspond to singular values at or below the cutoff:
     rcond == 0.0 (default): tol = max(s) * max(m, n) * eps
     otherwise:              tol = max(s) * rcond
   The result has a.getCols() rows and (a.getCols() - rank) columns.
*/
public func nullSpace(a: Matrix<Float64>, rcond!: Float64 = 0.0): Matrix<Float64> {
    let rows = a.getRows()
    let cols = a.getCols()

    let decomposition = svd(a)
    let sv = decomposition[1]
    let rightT = decomposition[2]

    // Same evaluation order as the reference formula, to keep rounding identical.
    let largest = max(sv)
    let cutoff = if (rcond == 0.0) {
        largest * Float64(max(rows, cols)) * Float64.eps()
    } else {
        largest * rcond
    }

    // Numerical rank: how many singular values lie above the cutoff.
    var rank: Int64 = 0
    for (k in 0..sv.size() where sv[k] > cutoff) {
        rank++
    }

    // The trailing rows of transpose(V), beyond the rank, span the null space.
    let width = rightT.getCols()
    let nullity = rightT.getRows() - rank
    let trailing = zeros<Float64>(nullity, width)
    for (r in 0..nullity) {
        let src = rank + r
        for (c in 0..width) {
            trailing[r, c] = rightT[src, c]
        }
    }
    return trailing.transpose()
}

/* Right polar decomposition A = U * P, built from the thin SVD
   A = W * diag(s) * Vh:
     U = W * Vh
     P = transpose(Vh) * diag(s) * Vh
   For an m-by-n input, U is m-by-n and P is n-by-n. Returns (U, P).
*/
public func polar(a: Matrix<Float64>): (Matrix<Float64>, Matrix<Float64>) {
    let (w, s, vh) = svd(a, full_matrices: false)
    // transpose(Vh) * diag(s) without materializing diag(s): scale column j by s[j].
    let weightedV = mat_multiply_vec(vh.transpose(), s)
    return (w * vh, weightedV * vh)
}

/* Return a copy of `a` whose column j is multiplied by b[j], i.e. a * diag(b).
   `a` itself is not modified.

   Throws IllegalArgumentException if b.size() != a.getCols(). Without this
   check, a longer `b` would have its extra entries silently ignored.
*/
public func mat_multiply_vec(a: Matrix<Float64>, b: Vector<Float64>): Matrix<Float64> {
    let cols = a.getCols()
    if (b.size() != cols) {
        throw IllegalArgumentException("mat_multiply_vec: vector length must equal the number of matrix columns")
    }
    let scaled = a.copy()
    let rows = scaled.getRows()
    for (j in 0..cols) {
        let factor = b[j]
        for (i in 0..rows) {
            scaled[i, j] = scaled[i, j] * factor
        }
    }
    return scaled
}

/* Moore-Penrose pseudo-inverse of the m-by-n matrix `a`, returned as an
   n-by-m matrix and computed from the thin SVD.

   Singular values s[i] <= tol are treated as zero, where
     rtol == 0.0 (default): tol = atol + max(s) * max(m, n) * eps
     otherwise:             tol = atol + max(s) * rtol

   Throws IllegalArgumentException if atol or rtol is negative. A negative
   tolerance would let exact-zero singular values through the cutoff and
   divide by zero.
*/
public func pinv(a: Matrix<Float64>, atol!: Float64 = 0.0, rtol!: Float64 = 0.0): Matrix<Float64> {
    if (atol < 0.0 || rtol < 0.0) {
        throw IllegalArgumentException("pinv: atol and rtol must be non-negative")
    }
    let m = a.getRows()
    let n = a.getCols()
    let minmn = min(m, n)

    // Thin SVD: u is m-by-minmn, vt is minmn-by-n.
    let (u, s, vt) = svd(a, full_matrices: false)

    let tol = if (rtol == 0.0) {
        atol + max(s) * Float64(max(m, n)) * Float64.eps()
    } else {
        atol + max(s) * rtol
    }

    // pinv(A) = V * diag(1/s) * U^T = (U * diag(1/s) * V^T)^T.
    // Scale column j of U by 1/s[j] (0 for discarded values) in place,
    // instead of forming and multiplying by an explicit diagonal matrix.
    for (j in 0..minmn) {
        let inv = if (s[j] > tol) { 1.0 / s[j] } else { 0.0 }
        for (i in 0..m) {
            u[i, j] = u[i, j] * inv
        }
    }
    return (u * vt).transpose()
}

/* Principal angles (in radians) between the column spaces of A and B.

   Follows the same steps as SciPy's scipy.linalg.subspace_angles:
   1. QA, QB = orthonormal bases of the column spaces (via orth).
   2. Cosines: sigma = svdvals(QA^T * QB).
   3. Sines: singular values of the residual of the smaller basis after
      projection onto the larger one.
   4. Where sigma[i]^2 >= 0.5 the angle is taken from the sine, otherwise
      from the cosine.

   The result has min(rank(A), rank(B)) entries. Ranks, not column counts,
   are used throughout, so rank-deficient inputs are handled.

   Throws IllegalArgumentException if A and B have different row counts.
*/
public func subspaceAngles(A: Matrix<Float64>, B: Matrix<Float64>): Vector<Float64> {
    if (A.getRows() != B.getRows()) {
        throw IllegalArgumentException("subspaceAngles: A and B must have the same number of rows")
    }

    // Orthonormal bases of the column spaces. Their column counts are the
    // numerical ranks, which may be smaller than A.getCols() / B.getCols().
    let QA = orth(A)
    let QB = orth(B)

    // Cosines of the principal angles.
    let QA_H_QB = QA.transpose() * QB
    let sigma: Vector<Float64> = svdvals(QA_H_QB).clip(-1.0, 1.0)

    // Residual of the lower-dimensional basis projected onto the other one.
    // The branch must compare the bases (ranks), not the raw column counts.
    var B2: Matrix<Float64>
    if (QA.getCols() >= QB.getCols()) {
        B2 = QB - QA * QA_H_QB
    } else {
        B2 = QA - QB * QA_H_QB.transpose()
    }
    let sigma2: Vector<Float64> = svdvals(B2).clip(-1.0, 1.0)

    // One principal angle per singular value of QA^T * QB.
    let c = sigma.size()
    let res = empty<Float64>(c)
    for (i in 0..c) {
        if (pow(sigma[i], 2.0) >= 0.5) {
            res[i] = asin(sigma2[i])
        } else {
            // The smallest cosine belongs to the largest angle, hence the reversed index.
            res[i] = acos(sigma[c-i-1])
        }
    }
    return res
}

/* Least-squares solution x of A * x ≈ b, computed as pinv(A) * b.

   `cond` is forwarded to pinv as its relative cutoff `rtol`. The value 0.0
   selects pinv's default tolerance.

   Throws IllegalArgumentException if b.size() != A.getRows().
*/
public func lstsq(A: Matrix<Float64>, b: Vector<Float64>, cond!: Float64 = 0.0): Vector<Float64> {
    let rows = A.getRows()
    if (b.size() != rows) {
        throw IllegalArgumentException("lstsq: dimension mismatch")
    }
    return pinv(A, rtol: cond) * b
}


/* Test the low-level LAPACKE_dgesvd binding directly. It decomposes a
   2-by-3 matrix (passed with LAPACK_COL_MAJOR layout, job 'a' for both
   factors) and compares U, S and transpose(V) against reference values.
   It fails immediately if LAPACKE reports a non-zero info code.
*/
func testDgesvd() {
    let LAPACK_COL_MAJOR: Int64 = 102
    let a = matrix<Float64>(
        [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
    )
    let jobu: Rune = r'a'
    let jobvt: Rune = r'a'
    let cjobu: UInt32 = UInt32(jobu)
    let cjobvt: UInt32 = UInt32(jobvt)
    let m = a.getRows()
    let n = a.getCols()
    let lda: Int64 = m
    let s = zeros<Float64>(min(m,n))
    let ldu: Int64 = m
    let u = zeros<Float64>(ldu,m)
    let ldvt: Int64 = n
    let vt = zeros<Float64>(ldvt,n)
    let superb = zeros<Float64>(min(m,n))
    let info = unsafe { LAPACKE_dgesvd(IntNative(LAPACK_COL_MAJOR), UInt8(cjobu), UInt8(cjobvt), IntNative(m),
                        IntNative(n), a.ptr, IntNative(lda), s.ptr, u.ptr, IntNative(ldu), vt.ptr, IntNative(ldvt), superb.ptr) }
    // The function under test signals failure only through info; check it
    // explicitly instead of relying on the value comparisons below.
    if (info != 0) {
        throw Exception("testDgesvd: LAPACKE_dgesvd failed with info = ${info}")
    }

    let expected_u = matrix<Float64>(
        [[-0.386318, -0.922366],
         [-0.922366, 0.386318]]
    )
    assertApproxEqual(u, expected_u, atol:1e-6)

    let expected_s = vector<Float64>(
        [9.508032, 0.772870]
    )
    assertApproxEqual(s, expected_s, atol:1e-6)

    let expected_vt = matrix<Float64>(
        [[-0.428667, -0.566307, -0.703947],
         [0.805964, 0.112382, -0.581199],
         [0.408248, -0.816497, 0.408248]]
    )
    assertApproxEqual(vt, expected_vt, atol:1e-6)
}

/* Unit tests for the SVD-based routines in this package. Each case compares
   against hard-coded reference values, with absolute tolerance 1e-6 where
   one is given. */
@Test
public class TestSVD {
    // Raw LAPACKE binding.
    @TestCase
    func testLapackDgesvd(): Unit {
        testDgesvd()
    }

    // pinv of a full-row-rank 2x3 matrix, plus the identity A * pinv(A) * A == A.
    @TestCase
    func testPinv1(): Unit {
        let A = matrix<Float64>(
            [[1.0, 2.0, 3.0],
             [4.0, 5.0, 6.0]]
        )
        let Aplus = pinv(A)
        let wantPinv = matrix<Float64>(
            [[-0.944444, 0.444444],
             [-0.111111, 0.111111],
             [0.722222, -0.222222]]
        )
        @Assert(approxEqual(Aplus, wantPinv, atol:1e-6))
        @Assert(approxEqual(A, A * Aplus * A, atol:1e-6))
    }

    // A diagonal 0/1 matrix is its own pseudo-inverse (exercises the rank cutoff).
    @TestCase
    func testPinv2(): Unit {
        let A = matrix<Float64>(
            [[1.0, 0.0, 0.0],
             [0.0, 1.0, 0.0],
             [0.0, 0.0, 0.0]]
        )
        let Aplus = pinv(A)
        @Assert(approxEqual(Aplus, A, atol:1e-6))
    }

    // Full SVD of a 2x3 matrix: U is 2x2, transpose(V) is 3x3.
    @TestCase
    func testSvd1(): Unit {
        let input = matrix<Float64>(
            [[1.0, 2.0, 3.0],
             [4.0, 5.0, 6.0]]
        )
        let (u, s, vt) = svd(input)
        let wantU = matrix<Float64>(
            [[-0.386318, -0.922366],
             [-0.922366, 0.386318]]
        )
        let wantS = vector<Float64>(
            [9.508032, 0.772870]
        )
        let wantVt = matrix<Float64>(
            [[-0.428667, -0.566307, -0.703947],
             [0.805964, 0.112382, -0.581199],
             [0.408248, -0.816497, 0.408248]]
        )
        @Assert(approxEqual(u, wantU, atol:1e-6))
        @Assert(approxEqual(s, wantS, atol:1e-6))
        @Assert(approxEqual(vt, wantVt, atol:1e-6))
    }

    // Thin SVD of a 2x3 matrix: transpose(V) is only 2x3.
    @TestCase
    func testSvd2(): Unit {
        let input = matrix<Float64>(
            [[0.5, 1.0, 2.0],
             [1.5, 3.0, 4.0]]
        )
        let (u, s, vt) = svd(input, full_matrices: false)
        let wantU = matrix<Float64>(
            [[-0.397854, -0.917449],
             [-0.917449, 0.397854]]
        )
        let wantS = vector<Float64>(
            [5.687303, 0.393168]
        )
        let wantVt = matrix<Float64>(
            [[-0.276950, -0.553901, -0.785171],
             [0.351139, 0.702278, -0.619280]]
        )
        @Assert(approxEqual(u, wantU, atol:1e-6))
        @Assert(approxEqual(s, wantS, atol:1e-6))
        @Assert(approxEqual(vt, wantVt, atol:1e-6))
    }

    // Singular values only.
    @TestCase
    func testSvdvals(): Unit {
        let input = matrix<Float64>(
            [[1.0, 2.0, 3.0],
             [4.0, 5.0, 6.0]]
        )
        let s = svdvals(input)
        let wantS = vector<Float64>(
            [9.508032, 0.772870]
        )
        @Assert(approxEqual(s, wantS, atol:1e-6))
    }

    // Rectangular Sigma matrix from a vector of singular values.
    @TestCase
    func testDiagsvd(): Unit {
        let values = vector<Float64>(
            [1.0, 2.0, 3.0]
        )
        let sigma = diagsvd(values, 3, 4)
        let wantSigma = matrix<Float64>(
            [[1.0, 0.0, 0.0, 0.0],
             [0.0, 2.0, 0.0, 0.0],
             [0.0, 0.0, 3.0, 0.0]]
        )
        @Assert(approxEqual(sigma, wantSigma, atol:1e-6))
    }

    // Range basis of a wide matrix and of its transpose.
    @TestCase
    func testOrth1(): Unit {
        let A = matrix<Float64>(
            [[2.0, 0.0, 0.0],
             [0.0, 5.0, 0.0]]
        )
        assertApproxEqual(orth(A), matrix<Float64>(
            [[0.0, 1.0],
             [1.0, 0.0]]
        ))
        @Assert(approxEqual(orth(A.transpose()), matrix<Float64>(
            [[0.0, 1.0],
            [1.0, 0.0],
            [0.0, 0.0]]
        )))
    }

    // Range basis of a full-rank 3x3 matrix.
    @TestCase
    func testOrth2(): Unit {
        let A = matrix<Float64>(
            [[ 1.0,  0.0,  1.0],
             [-1.0, -2.0,  0.0],
             [ 0.0,  1.0, -1.0]]
        )
        let basis = orth(A)
        let wantBasis = matrix<Float64>(
            [[-0.120000, -0.809712, 0.574427],
             [ 0.901753,  0.153123, 0.404222],
             [-0.415261,  0.566498, 0.711785]]
        )
        @Assert(approxEqual(basis, wantBasis, atol:1e-6))
    }

    // One-dimensional null space of a rank-1 2x2 matrix.
    @TestCase
    func testNullspace1(): Unit {
        let A = matrix<Float64>([[1.0, 1.0], [1.0, 1.0]])
        var ns = nullSpace(A)
        // Singular vectors are only defined up to sign; normalize it.
        if (ns[0,0] < 0.0) {
            ns = -1.0 * ns
        }
        let wantNs = matrix<Float64>([[0.707107], [-0.707107]])
        @Assert(approxEqual(ns, wantNs, atol:1e-6))
    }

    // Two-dimensional null space of a 2x4 matrix.
    @TestCase
    func testNullspace2(): Unit {
        let A = matrix<Float64>(
            [[0.0, 1.0, 2.0, 4.0],
             [1.0, 2.0, 3.0, 4.0]]
        )
        var ns = nullSpace(A)
        // Normalize the sign of each basis vector by its first component.
        for (col in 0..ns.getCols()) {
            if (ns[0,col] < 0.0) {
                ns.setCol(col, -1.0 * ns.getCol(col))
            }
        }
        let wantNs = matrix<Float64>(
            [[ 0.486665,  0.611775],
             [ 0.279469, -0.767179],
             [-0.766134,  0.155404],
             [ 0.313200,  0.114093]]
        )
        @Assert(approxEqual(ns, wantNs, atol:1e-6))
    }

    // Polar decomposition of a square matrix.
    @TestCase
    func testPolar1(): Unit {
        let A = matrix<Float64>([[1.0, -1.0], [2.0, 4.0]])
        let (U, P) = polar(A)
        let wantU = matrix<Float64>(
            [[0.857493, -0.514496],
             [0.514496,  0.857493]]
        )
        let wantP = matrix<Float64>(
            [[1.886484, 1.200490],
             [1.200490, 3.944467]]
        )
        @Assert(approxEqual(U, wantU, atol:1e-6))
        @Assert(approxEqual(P, wantP, atol:1e-6))
    }

    // Polar decomposition of a wide matrix (m < n).
    @TestCase
    func testPolar2(): Unit {
        let A = matrix<Float64>([[0.5, 1.0, 2.0], [1.5, 3.0, 4.0]])
        let (U, P) = polar(A)
        let wantU = matrix<Float64>(
            [[-0.211966, -0.423932, 0.880541],
             [ 0.393790,  0.787579, 0.473971]]
        )
        let wantP = matrix<Float64>(
            [[0.484701, 0.969403, 1.151226],
             [0.969403, 1.938806, 2.302453],
             [1.151226, 2.302453, 3.656964]]
        )
        @Assert(approxEqual(U, wantU, atol:1e-6))
        @Assert(approxEqual(P, wantP, atol:1e-6))
    }

    // Polar decomposition of a tall matrix (m > n).
    @TestCase
    func testPolar3(): Unit {
        let A = matrix<Float64>([[0.5, 1.5], [1.0, 3.0], [2.0, 4.0]])
        let (U, P) = polar(A)
        let wantU = matrix<Float64>(
            [[-0.211966, 0.393790],
             [-0.423932, 0.787579],
             [ 0.880541, 0.473971]]
        )
        let wantP = matrix<Float64>(
            [[1.231166, 1.932416],
             [1.932416, 4.849306]]
        )
        @Assert(approxEqual(U, wantU, atol:1e-6))
        @Assert(approxEqual(P, wantP, atol:1e-6))
    }

    // Mutually orthogonal subspaces: both principal angles are pi/2.
    @TestCase
    func testSubspaceAngles1(): Unit {
        let A = matrix<Float64>(
            [[1.0,  1.0],
             [1.0, -1.0],
             [1.0,  1.0],
             [1.0, -1.0]]
        )
        let B = matrix<Float64>(
            [[ 1.0,  1.0],
             [ 1.0, -1.0],
             [-1.0, -1.0],
             [-1.0,  1.0]]
        )
        let angles = subspaceAngles(A, B)
        @Assert(approxEqual(angles, vector([1.570796, 1.570796]), atol:1e-6))
    }

    // A subspace makes zero angles with itself.
    @TestCase
    func testSubspaceAngles2(): Unit {
        let A = matrix<Float64>(
            [[1.0,  1.0],
             [1.0, -1.0],
             [1.0,  1.0],
             [1.0, -1.0]]
        )
        let angles = subspaceAngles(A, A)
        @Assert(approxEqual(angles, vector([0.0, 0.0]), atol:1e-6))
    }

    // Fit y ≈ p0 + p1 * x^2 by least squares.
    @TestCase
    func testLstsq1(): Unit {
        let x = vector([1.0, 2.5, 3.5, 4.0, 5.0, 7.0, 8.5])
        let y = vector([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])
        let design = empty<Float64>(x.size(), 2)
        for (row in 0..x.size()) {
            let xi = x[row]
            design[row,0] = 1.0
            design[row,1] = pow(xi, 2.0)
        }
        let p = lstsq(design, y)
        let wantP = vector([0.209258, 0.120139])
        @Assert(approxEqual(p, wantP, atol:1e-6))
    }
}