package scientific.linear

import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.utils.assertEqual

/* Solve a system of linear equations A * X = B for a general (dense) Float64 matrix A.
 * Per the LAPACK reference, the return value is INFO: 0 on success, -i if argument i
 * had an illegal value, and i > 0 if U(i,i) is exactly zero (A is singular and no
 * solution was computed).
 * NOTE(review): the Int64 return type and Int64 pivot buffer assume an ILP64 LAPACKE
 * build (64-bit lapack_int) — confirm against the library this links to. */
foreign func LAPACKE_dgesv(
    matrix_layout: IntNative,  // row or column major (this file passes LAPACK_COL_MAJOR)
    n: IntNative,              // order of A (A is n x n)
    nrhs: IntNative,           // number of right-hand sides (columns of B)
    a: CPointer<Unit>,         // matrix A (Float64); overwritten with its LU factors
    lda: IntNative,            // leading dimension of A, must be >= max(1, n)
    ipiv: CPointer<Unit>,      // pivot indices (output, Int64)
    b: CPointer<Unit>,         // right side B (Float64); overwritten with the solution X
    ldb: IntNative             // leading dimension of B, must be >= max(1, n)
): Int64

/* Solve a system of linear equations A * X = B for a general (dense) Complex64 matrix A.
 * Complex counterpart of LAPACKE_dgesv; the return value is LAPACK's INFO code
 * (0 = success, -i = argument i illegal, i > 0 = U(i,i) exactly zero, A singular).
 * NOTE(review): Int64 return / Int64 pivots assume an ILP64 LAPACKE build — confirm. */
foreign func LAPACKE_zgesv(
    matrix_layout: IntNative,  // row or column major (this file passes LAPACK_COL_MAJOR)
    n: IntNative,              // order of A (A is n x n)
    nrhs: IntNative,           // number of right-hand sides (columns of B)
    a: CPointer<Unit>,         // matrix A (Complex64); overwritten with its LU factors
    lda: IntNative,            // leading dimension of A, must be >= max(1, n)
    ipiv: CPointer<Unit>,      // pivot indices (output, Int64)
    b: CPointer<Unit>,         // right side B (Complex64); overwritten with the solution X
    ldb: IntNative             // leading dimension of B, must be >= max(1, n)
): Int64

/**
 * Solves the dense linear system A * X = B for a general square Float64 matrix A.
 *
 * Wraps LAPACKE_dgesv (LU factorization with partial pivoting). Both inputs are
 * copied before the call, so `a` and `b` are left untouched.
 *
 * @param a square n x n coefficient matrix.
 * @param b n x nrhs right-hand side matrix.
 * @return the n x nrhs solution matrix X.
 * @throws IllegalArgumentException if `a` is not square, the row counts of `a` and
 *         `b` differ, `a` is singular, or LAPACK rejects one of the arguments.
 */
public func solve(a: Matrix<Float64>, b: Matrix<Float64>): Matrix<Float64> {
    if (!a.isSquare() || a.getRows() != b.getRows()) {
        throw IllegalArgumentException("solve: dimension mismatch")
    }

    let n: Int64 = a.getRows()
    let nrhs: Int64 = b.getCols()
    // LAPACK requires leading dimensions >= max(1, n), even for an empty system.
    let ld: Int64 = if (n > 1) { n } else { 1 }
    let ipiv = zeros<Int64>(n)
    // dgesv overwrites A with its LU factors and B with the solution, so work on copies.
    let acopy = cj_dcopy(a)
    let bcopy = cj_dcopy(b)
    let info = unsafe { LAPACKE_dgesv(LAPACK_COL_MAJOR, IntNative(n), IntNative(nrhs),
                        acopy.ptr, IntNative(ld), ipiv.ptr, bcopy.ptr, IntNative(ld)) }
    if (info < 0) {
        throw IllegalArgumentException("solve: LAPACKE_dgesv rejected argument ${-info}")
    }
    if (info > 0) {
        // U(info, info) is exactly zero: bcopy does not hold a solution.
        throw IllegalArgumentException("solve: matrix is singular (U(${info},${info}) is exactly zero)")
    }
    return bcopy
}

/**
 * Solves the dense linear system A * X = B for a general square Complex64 matrix A.
 *
 * Wraps LAPACKE_zgesv (LU factorization with partial pivoting). Both inputs are
 * copied before the call, so `a` and `b` are left untouched.
 *
 * @param a square n x n coefficient matrix.
 * @param b n x nrhs right-hand side matrix.
 * @return the n x nrhs solution matrix X.
 * @throws IllegalArgumentException if `a` is not square, the row counts of `a` and
 *         `b` differ, `a` is singular, or LAPACK rejects one of the arguments.
 */
public func solve(a: Matrix<Complex64>, b: Matrix<Complex64>): Matrix<Complex64> {
    if (!a.isSquare() || a.getRows() != b.getRows()) {
        throw IllegalArgumentException("solve: dimension mismatch")
    }

    let n: Int64 = a.getRows()
    let nrhs: Int64 = b.getCols()
    // LAPACK requires leading dimensions >= max(1, n), even for an empty system.
    let ld: Int64 = if (n > 1) { n } else { 1 }
    let ipiv = zeros<Int64>(n)
    // zgesv overwrites A with its LU factors and B with the solution, so work on copies.
    let acopy = cj_zcopy(a)
    let bcopy = cj_zcopy(b)
    let info = unsafe { LAPACKE_zgesv(LAPACK_COL_MAJOR, IntNative(n), IntNative(nrhs),
                        acopy.ptr, IntNative(ld), ipiv.ptr, bcopy.ptr, IntNative(ld)) }
    if (info < 0) {
        throw IllegalArgumentException("solve: LAPACKE_zgesv rejected argument ${-info}")
    }
    if (info > 0) {
        // U(info, info) is exactly zero: bcopy does not hold a solution.
        throw IllegalArgumentException("solve: matrix is singular (U(${info},${info}) is exactly zero)")
    }
    return bcopy
}

/** Single right-hand-side overload: solves A * x = b for a Float64 vector b. */
public func solve(a: Matrix<Float64>, b: Vector<Float64>): Vector<Float64> {
    let rhs = b.toMatrix()
    let solution = solve(a, rhs)
    return solution.toVector()
}

/** Single right-hand-side overload: solves A * x = b for a Complex64 vector b. */
public func solve(a: Matrix<Complex64>, b: Vector<Complex64>): Vector<Complex64> {
    let rhs = b.toMatrix()
    let solution = solve(a, rhs)
    return solution.toVector()
}

/* Solve a system of linear equations A * X = B (or A^T * X = B) with a triangular Float64 matrix A.
 * Per the LAPACK reference, the return value is INFO: 0 on success, -i if argument i was
 * illegal, and i > 0 if A(i,i) is zero (A singular; the solution has NOT been computed
 * and B is left unchanged). */
foreign func LAPACKE_dtrtrs(
    matrix_layout: IntNative,  // row or column layout
    uplo: UInt8,               // ASCII 'U' (upper triangular) or 'L' (lower triangular)
    trans: UInt8,              // ASCII 'N' (A), 'T' (transpose) or 'C' (conjugate transpose)
    diag: UInt8,               // ASCII 'N' (non-unit diagonal) or 'U' (unit diagonal, not read)
    n: IntNative,              // order of A (A is n x n)
    nrhs: IntNative,           // number of right-hand sides (columns of B)
    a: CPointer<Unit>,         // matrix A (Float64), read only
    lda: IntNative,            // leading dimension of A, must be >= max(1, n)
    b: CPointer<Unit>,         // matrix B (Float64); overwritten with the solution X
    ldb: IntNative             // leading dimension of B, must be >= max(1, n)
): Int64

/**
 * Solves A * X = B where A is a square triangular Float64 matrix.
 *
 * Wraps LAPACKE_dtrtrs with no transpose and a non-unit diagonal. Only the triangle
 * selected by `lower` is referenced. Both inputs are copied, so `a` and `b` are
 * left untouched.
 *
 * @param a square n x n triangular coefficient matrix.
 * @param b n x nrhs right-hand side matrix.
 * @param lower true if `a` is lower triangular, false if upper triangular.
 * @return the n x nrhs solution matrix X.
 * @throws IllegalArgumentException if `a` is not square, the row counts of `a` and
 *         `b` differ, `a` has a zero on its diagonal (singular), or LAPACK rejects
 *         one of the arguments.
 */
public func solveTriangular(a: Matrix<Float64>, b: Matrix<Float64>, lower: Bool): Matrix<Float64> {
    if (!a.isSquare() || a.getRows() != b.getRows()) {
        throw IllegalArgumentException("solve: dimension mismatch")
    }

    let colMajor: Int64 = 102  // LAPACK_COL_MAJOR
    let n: Int64 = a.getRows()
    let nrhs: Int64 = b.getCols()
    // LAPACK requires leading dimensions >= max(1, n), even for an empty system.
    let ld: Int64 = if (n > 1) { n } else { 1 }
    // LAPACK takes its option flags as single ASCII characters.
    let uplo: Rune = if (lower) { r'L' } else { r'U' }
    let trans: Rune = r'N'  // solve with A itself, not its transpose
    let diag: Rune = r'N'   // diagonal is stored explicitly (non-unit)
    let acopy = a.copy()
    let bcopy = b.copy()
    let info = unsafe { LAPACKE_dtrtrs(IntNative(colMajor),
                        UInt8(UInt32(uplo)), UInt8(UInt32(trans)), UInt8(UInt32(diag)),
                        IntNative(n), IntNative(nrhs), acopy.ptr,
                        IntNative(ld), bcopy.ptr, IntNative(ld)) }
    if (info < 0) {
        throw IllegalArgumentException("solveTriangular: LAPACKE_dtrtrs rejected argument ${-info}")
    }
    if (info > 0) {
        // dtrtrs detects the zero diagonal before solving, so bcopy would still equal b.
        throw IllegalArgumentException("solveTriangular: matrix is singular (A(${info},${info}) is zero)")
    }
    return bcopy
}

/** Single right-hand-side overload: solves the triangular system A * x = b for a vector b. */
public func solveTriangular(a: Matrix<Float64>, b: Vector<Float64>, lower: Bool): Vector<Float64> {
    let rhs = b.toMatrix()
    let solution = solveTriangular(a, rhs, lower)
    return solution.toVector()
}

/* Solve a system of linear equations A * X = B with a general banded Float64 matrix A.
 * Per the LAPACK reference, the return value is INFO: 0 on success, -i if argument i was
 * illegal, and i > 0 if U(i,i) is exactly zero (A singular; no solution computed). */
foreign func LAPACKE_dgbsv(
    matrix_layout: IntNative,  // row or column major
    n: IntNative,              // order of A (A is n x n)
    kl: IntNative,             // number of subdiagonals
    ku: IntNative,             // number of superdiagonals
    nrhs: IntNative,           // number of right-hand sides
    ab: CPointer<Unit>,        // matrix A (band storage, Float64); first kl rows are workspace; overwritten with LU factors
    ldab: IntNative,           // leading dimension of A (band storage), must be >= 2*kl + ku + 1
    ipiv: CPointer<Unit>,      // pivot indices (output, Int64)
    b: CPointer<Unit>,         // right-side B (overwritten, Float64)
    ldb: IntNative             // leading dimension of B, must be >= max(1, n)
): Int64

/**
 * Solves A * X = B for a general banded Float64 matrix A with kl subdiagonals and ku
 * superdiagonals.
 *
 * `A` is the LAPACK band storage of the n x n system matrix: it has n columns and
 * ldab >= 2*kl + ku + 1 rows, with original[i, j] stored at A[kl + ku + i - j, j]
 * (0-based). The first kl rows are workspace for fill-in created by pivoting.
 * Both inputs are copied, so `A` and `b` are left untouched.
 *
 * @param kl number of subdiagonals (>= 0).
 * @param ku number of superdiagonals (>= 0).
 * @param A band storage, ldab x n.
 * @param b n x nrhs right-hand side matrix.
 * @return the n x nrhs solution matrix X.
 * @throws IllegalArgumentException if kl or ku is negative, `A` has fewer than
 *         2*kl + ku + 1 rows, `b` does not have n rows, the matrix is singular,
 *         or LAPACK rejects one of the arguments.
 */
public func solveBanded(kl: Int64, ku: Int64, A: Matrix<Float64>, b: Matrix<Float64>): Matrix<Float64> {
    if (kl < 0 || ku < 0) {
        throw IllegalArgumentException("solveBanded: kl and ku must be non-negative")
    }
    // Band storage is ldab x n, so n comes from the column count, not the row count.
    let n: Int64 = A.getCols()
    let ldab: Int64 = A.getRows()
    if (ldab < 2 * kl + ku + 1 || b.getRows() != n) {
        throw IllegalArgumentException("solve: dimension mismatch")
    }

    let colMajor: Int64 = 102  // LAPACK_COL_MAJOR
    let nrhs: Int64 = b.getCols()
    // LAPACK requires ldb >= max(1, n), even for an empty system.
    let ldb: Int64 = if (n > 1) { n } else { 1 }
    let ipiv = zeros<Int64>(n)
    // dgbsv overwrites the band storage with its LU factors and b with the solution.
    let Acopy = A.copy()
    let bcopy = b.copy()
    let info = unsafe { LAPACKE_dgbsv(IntNative(colMajor), IntNative(n),
                        IntNative(kl), IntNative(ku), IntNative(nrhs), Acopy.ptr,
                        IntNative(ldab), ipiv.ptr, bcopy.ptr, IntNative(ldb)) }
    if (info < 0) {
        throw IllegalArgumentException("solveBanded: LAPACKE_dgbsv rejected argument ${-info}")
    }
    if (info > 0) {
        // U(info, info) is exactly zero: bcopy does not hold a solution.
        throw IllegalArgumentException("solveBanded: matrix is singular (U(${info},${info}) is exactly zero)")
    }
    return bcopy
}

/** Single right-hand-side overload: solves the banded system A * x = b for a vector b. */
public func solveBanded(kl: Int64, ku: Int64, A: Matrix<Float64>, b: Vector<Float64>): Vector<Float64> {
    let rhs = b.toMatrix()
    let solution = solveBanded(kl, ku, A, rhs)
    return solution.toVector()
}

/* Solve system of linear equations A * X = B with a symmetric positive definite band matrix A.
 * Per the LAPACK reference, the return value is INFO: 0 on success, -i if argument i was
 * illegal, and i > 0 if the leading minor of order i is not positive definite (the
 * Cholesky factorization could not be completed and no solution was computed). */
foreign func LAPACKE_dpbsv(
    matrix_layout: IntNative,  // row or column major
    uplo: UInt8,               // which triangle is stored: ASCII 'U' (upper) or 'L' (lower)
    n: IntNative,              // order of A (A is n x n)
    kd: IntNative,             // number of super-diagonals (uplo 'U') or sub-diagonals (uplo 'L')
    nrhs: IntNative,           // number of right-hand sides
    ab: CPointer<Unit>,        // matrix A (band storage, Float64); overwritten with the Cholesky factor
    ldab: IntNative,           // leading dimension of A (band storage), must be >= kd + 1
    b: CPointer<Unit>,         // right-side B (overwritten, Float64)
    ldb: IntNative             // leading dimension of B, must be >= max(1, n)
): Int64

/**
 * Solves A * X = B for a real symmetric positive definite band matrix A.
 *
 * Wraps LAPACKE_dpbsv (banded Cholesky). `AB` holds one triangle of A in LAPACK
 * symmetric band storage. It is (kd + 1) x n, where kd is the number of
 * off-diagonal bands.
 * - lower = false: the upper triangle is stored, AB[kd + i - j, j] = A[i, j], with the
 *   diagonal in the last row.
 * - lower = true: the lower triangle is stored, AB[i - j, j] = A[i, j], with the
 *   diagonal in the first row.
 * Both inputs are copied, so `AB` and `b` are left untouched.
 *
 * @param AB band storage of A, (kd + 1) x n.
 * @param b n x nrhs right-hand side matrix.
 * @param lower whether `AB` stores the lower (true) or upper (false) triangle.
 * @return the n x nrhs solution matrix X.
 * @throws IllegalArgumentException if `AB` has no rows, `b` does not have n rows,
 *         A is not positive definite, or LAPACK rejects one of the arguments.
 */
public func solveHBanded(AB: Matrix<Float64>, b: Matrix<Float64>, lower!:Bool = false): Matrix<Float64> {
    let n: Int64 = AB.getCols()
    if (AB.getRows() < 1 || b.getRows() != n) {
        throw IllegalArgumentException("solve: dimension mismatch")
    }

    let colMajor: Int64 = 102  // LAPACK_COL_MAJOR
    let uplo: UInt8 = if (lower) { LAPACK_LOWER_TRIANGULAR } else { LAPACK_UPPER_TRIANGULAR }
    let kd: Int64 = AB.getRows() - 1
    let ldab: Int64 = kd + 1
    let nrhs: Int64 = b.getCols()
    // LAPACK requires ldb >= max(1, n), even for an empty system.
    let ldb: Int64 = if (n > 1) { n } else { 1 }
    // dpbsv overwrites the band storage with its Cholesky factor and b with the solution.
    let ABcopy = AB.copy()
    let bcopy = b.copy()
    let info = unsafe { LAPACKE_dpbsv(IntNative(colMajor), uplo, IntNative(n),
                        IntNative(kd), IntNative(nrhs), ABcopy.ptr,
                        IntNative(ldab), bcopy.ptr, IntNative(ldb)) }
    if (info < 0) {
        throw IllegalArgumentException("solveHBanded: LAPACKE_dpbsv rejected argument ${-info}")
    }
    if (info > 0) {
        // The leading minor of order `info` is not positive definite; no solution was computed.
        throw IllegalArgumentException("solveHBanded: matrix is not positive definite (leading minor ${info})")
    }
    return bcopy
}

/** Single right-hand-side overload: solves the symmetric positive definite banded system A * x = b. */
public func solveHBanded(A: Matrix<Float64>, b: Vector<Float64>, lower!:Bool = false): Vector<Float64> {
    let rhs = b.toMatrix()
    let solution = solveHBanded(A, rhs, lower: lower)
    return solution.toVector()
}

/* Calls LAPACKE_dgesv directly on a 3x3 system whose exact solution is all ones. */
func testDgesv() {
    let colMajor: Int64 = 102  // LAPACK_COL_MAJOR
    let n: Int64 = 3
    let coeffs = matrix<Float64>(
        [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0],
         [7.0, 8.0, 0.0]]
    )
    // Right-hand side; dgesv overwrites it in place with the solution.
    let rhs = vector<Float64>([6.0, 15.0, 15.0])
    let pivots = zeros<Int64>(n)
    let info = unsafe {
        LAPACKE_dgesv(IntNative(colMajor), IntNative(n), IntNative(1),
                      coeffs.ptr, IntNative(n), pivots.ptr, rhs.ptr, IntNative(n))
    }
    assertEqual(info, 0)

    let ones = vector<Float64>([1.0, 1.0, 1.0])
    assertApproxEqual(rhs, ones)
}

/* Calls LAPACKE_dgbsv directly on a 5x5 banded system (kl = 1, ku = 2) and checks
 * both the return code and the computed solution. */
func testDgbsv() {
    let colMajor: Int64 = 102  // LAPACK_COL_MAJOR
    // LAPACK band storage of the matrix
    //   [[5, 2, -1,  0,  0],
    //    [1, 4,  2, -1,  0],
    //    [0, 1,  3,  2, -1],
    //    [0, 0,  1,  2,  2],
    //    [0, 0,  0,  1,  1]]
    // Row 0 is workspace for fill-in, rows 1-2 hold the superdiagonals,
    // row 3 the diagonal and row 4 the subdiagonal.
    let AB = matrix<Float64>(
        [[0.0, 0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, -1.0, -1.0, -1.0],
         [0.0, 2.0, 2.0, 2.0, 2.0],
         [5.0, 4.0, 3.0, 2.0, 1.0],
         [1.0, 1.0, 1.0, 1.0, 0.0]]
    )
    // Right-hand side; dgbsv overwrites it in place with the solution.
    let B = vector<Float64>([0.0, 1.0, 2.0, 2.0, 3.0])

    let kl: Int64 = 1
    let ku: Int64 = 2
    let N: Int64 = 5
    let ipiv = zeros<Int64>(N)
    let nrhs: Int64 = 1
    let ldab: Int64 = 2 * kl + ku + 1
    let ldb = N
    let info = unsafe { LAPACKE_dgbsv(IntNative(colMajor), IntNative(N),
                        IntNative(kl), IntNative(ku), IntNative(nrhs), AB.ptr,
                        IntNative(ldab), ipiv.ptr, B.ptr, IntNative(ldb)) }
    assertEqual(info, 0)

    // Exact solution: (-140, 232, -236, 257, -80) / 59.
    let expected = vector<Float64>(
        [-2.372881, 3.932203, -4.000000, 4.355932, -1.355932]
    )
    assertApproxEqual(B, expected, atol:1e-6)
}

/* Calls LAPACKE_dtrtrs directly on a 4x4 lower-triangular system and checks both
 * the return code and the computed solution. */
func testDtrtrs() {
    let colMajor: Int64 = 102  // LAPACK_COL_MAJOR
    let a = matrix<Float64>(
        [[3.0, 0.0, 0.0, 0.0],
         [2.0, 1.0, 0.0, 0.0],
         [1.0, 0.0, 1.0, 0.0],
         [1.0, 1.0, 1.0, 1.0]]
    )
    // Right-hand side; dtrtrs overwrites it in place with the solution.
    let b = vector<Float64>([4.0, 2.0, 4.0, 2.0])

    let N: Int64 = 4
    let nrhs: Int64 = 1
    // LAPACK option flags as ASCII bytes: lower triangle, no transpose, non-unit diagonal.
    let uplo = UInt8(UInt32(r'L'))
    let trans = UInt8(UInt32(r'N'))
    let diag = UInt8(UInt32(r'N'))
    let info = unsafe { LAPACKE_dtrtrs(IntNative(colMajor),
                        uplo, trans, diag, IntNative(N),
                        IntNative(nrhs), a.ptr,
                        IntNative(N), b.ptr, IntNative(N)) }
    assertEqual(info, 0)

    // Forward substitution gives (4/3, -2/3, 8/3, -4/3).
    let expected = vector<Float64>([1.333333, -0.666667, 2.666667, -1.333333])
    assertApproxEqual(b, expected, atol:1e-6)
}

/* Calls LAPACKE_dpbsv directly on a 6x6 symmetric positive definite band matrix
 * (kd = 2, lower triangle stored) and checks the return code and the solution. */
func testDpbsv() {
    // Lower symmetric band storage: row 0 is the diagonal, rows 1-2 the subdiagonals.
    let AB = matrix<Float64>(
        [[ 4.0,  5.0,  6.0,  7.0,  8.0,  9.0],
         [ 2.0,  2.0,  2.0,  2.0,  2.0,  0.0],
         [-1.0, -1.0, -1.0, -1.0,  0.0,  0.0]]
    )
    // Right-hand side; dpbsv overwrites it in place with the solution.
    let B = vector<Float64>(
        [1.0, 2.0, 2.0, 3.0, 3.0, 3.0]
    )

    let kd: Int64 = 2
    let N: Int64 = 6
    let nrhs: Int64 = 1
    let ldab: Int64 = kd + 1
    let ldb = N
    let info = unsafe{ LAPACKE_dpbsv(IntNative(LAPACK_COL_MAJOR), LAPACK_LOWER_TRIANGULAR, IntNative(N),
                        IntNative(kd), IntNative(nrhs), AB.ptr,
                        IntNative(ldab), B.ptr, IntNative(ldb)) }
    // A non-zero info would mean B was never overwritten with a solution.
    assertEqual(info, 0)

    let expected_x = vector<Float64>(
        [0.034314, 0.459384, 0.056022, 0.477591, 0.175770, 0.347339]
    )
    assertApproxEqual(B, expected_x, atol:1e-6)
}

/* Test suite covering the raw LAPACKE bindings and the high-level solve* wrappers. */
@Test
public class TestSolve {
    // ---- Raw LAPACKE bindings (delegating to the package-level helper functions) ----

    @TestCase
    func testLapackDpbsv(): Unit {
        testDpbsv()
    }

    @TestCase
    func testLapackDgesv(): Unit {
        testDgesv()
    }

    @TestCase
    func testLapackDgbsv(): Unit {
        testDgbsv()
    }

    @TestCase
    func testLapackDtrtrs(): Unit {
        testDtrtrs()
    }

    // ---- High-level wrappers ----

    /* Dense real 3x3 system whose exact solution is all ones. */
    @TestCase
    func testSolve(): Unit {
        let coeffs = matrix<Float64>(
            [[1.0, 2.0, 3.0],
             [4.0, 5.0, 6.0],
             [7.0, 8.0, 0.0]]
        )
        let rhs = vector<Float64>([6.0, 15.0, 15.0])
        let solution = solve(coeffs, rhs)
        let ones = vector<Float64>([1.0, 1.0, 1.0])
        @Assert(approxEqual(solution, ones))
    }

    /* The same system scaled by i on both sides, so the solution is still all ones. */
    @TestCase
    func testSolveComplex64(): Unit {
        let coeffs = matrix<Complex64>([
            [Complex64(0.0, 1.0), Complex64(0.0, 2.0), Complex64(0.0, 3.0)],
            [Complex64(0.0, 4.0), Complex64(0.0, 5.0), Complex64(0.0, 6.0)],
            [Complex64(0.0, 7.0), Complex64(0.0, 8.0), Complex64(0.0, 0.0)]
        ])
        let rhs = vector<Complex64>([
            Complex64(0.0, 6.0), Complex64(0.0, 15.0), Complex64(0.0, 15.0)
        ])
        let solution: Vector<Complex64> = solve(coeffs, rhs)
        let ones = vector<Complex64>([
            Complex64(1.0, 0.0), Complex64(1.0, 0.0), Complex64(1.0, 0.0)
        ])
        @Assert(approxEqualC(solution, ones))
    }

    /* 5x5 banded system with kl = 1 and ku = 2, given in LAPACK band storage. */
    @TestCase
    func testSolveBanded(): Unit {
        let band = matrix<Float64>(
            [[0.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 0.0, -1.0, -1.0, -1.0],
             [0.0, 2.0, 2.0, 2.0, 2.0],
             [5.0, 4.0, 3.0, 2.0, 1.0],
             [1.0, 1.0, 1.0, 1.0, 0.0]]
        )
        let rhs = vector<Float64>([0.0, 1.0, 2.0, 2.0, 3.0])
        let solution = solveBanded(1, 2, band, rhs)
        let expected = vector<Float64>(
            [-2.372881, 3.932203, -4.000000, 4.355932, -1.355932]
        )
        @Assert(approxEqual(solution, expected, atol:1e-6))
    }

    /* 4x4 lower-triangular system solved by forward substitution. */
    @TestCase
    func testSolveTriangular(): Unit {
        let lowerTri = matrix<Float64>(
            [[3.0, 0.0, 0.0, 0.0],
             [2.0, 1.0, 0.0, 0.0],
             [1.0, 0.0, 1.0, 0.0],
             [1.0, 1.0, 1.0, 1.0]]
        )
        let rhs = vector<Float64>([4.0, 2.0, 4.0, 2.0])
        let solution = solveTriangular(lowerTri, rhs, true)
        let expected = vector<Float64>(
            [1.333333, -0.666667, 2.666667, -1.333333]
        )
        @Assert(approxEqual(solution, expected, atol:1e-6))
    }

    /* 6x6 symmetric positive definite band matrix (kd = 2) stored as its lower triangle. */
    @TestCase
    func testSolveHBanded(): Unit {
        let band = matrix<Float64>(
            [[ 4.0,  5.0,  6.0,  7.0,  8.0,  9.0],
             [ 2.0,  2.0,  2.0,  2.0,  2.0,  0.0],
             [-1.0, -1.0, -1.0, -1.0,  0.0,  0.0]]
        )
        let rhs = vector<Float64>([1.0, 2.0, 2.0, 3.0, 3.0, 3.0])
        let solution = solveHBanded(band, rhs, lower: true)
        let expected = vector<Float64>(
            [0.034314, 0.459384, 0.056022, 0.477591, 0.175770, 0.347339]
        )
        @Assert(approxEqual(solution, expected, atol:1e-6))
    }
}