package scientific.linear

import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*

// Binding to LAPACKE_dgees: real Schur factorization A = Z * T * Z^T of a
// general square Float64 matrix (optionally with eigenvalue ordering).
// NOTE(review): the C prototype uses lapack_int for n/lda/*sdim/ldvs and the
// return value, and a function pointer (LAPACK_D_SELECT2) for `select`.
// Confirm these Cangjie types match the linked LAPACKE build (LP64 vs ILP64).
foreign func LAPACKE_dgees(
    matrix_layout: IntNative,  // storage order: LAPACK_ROW_MAJOR or LAPACK_COL_MAJOR
    jobvs: UInt8,              // 'V': compute Schur vectors Z; 'N': do not
    sort: UInt8,               // 'S': order eigenvalues via `select`; 'N': no ordering
    select: Bool,              // C side expects a selection callback; only referenced when sort = 'S' (callers here pass false with sort = 'N')
    n: IntNative,              // order of the square matrix A (n-by-n)
    a: CPointer<Unit>,         // in: n-by-n Float64 matrix A; out: overwritten with the Schur form T
    lda: IntNative,            // leading dimension of A, must be >= max(1, n)
    sdim: CPointer<Unit>,      // out: number of eigenvalues selected by `select` (0 when sort = 'N'); callers pass an Int64 buffer
    wr: CPointer<Unit>,        // out: real parts of the n eigenvalues (Float64)
    wi: CPointer<Unit>,        // out: imaginary parts of the n eigenvalues (Float64)
    vs: CPointer<Unit>,        // out: Schur vectors Z (Float64, ldvs-by-n) when jobvs = 'V'
    ldvs: IntNative     // leading dimension of vs, must be >= 1 (and >= n when jobvs = 'V')
): Int64  // info: 0 on success, < 0 if argument -info was illegal, > 0 if the QR algorithm (or sorting) failed

/**
 * Computes the real Schur decomposition A = Z * T * Z^T of a square matrix
 * using LAPACK's dgees, without eigenvalue reordering.
 *
 * The input is not modified: dgees overwrites its matrix argument in place,
 * so it is run on a copy.
 *
 * @param a square n-by-n matrix to factorize.
 * @return (T, Z), where T is the quasi-upper-triangular Schur form (1x1
 *         diagonal blocks for real eigenvalues, 2x2 blocks for complex-
 *         conjugate pairs) and Z is the orthogonal matrix of Schur vectors.
 * @throws IllegalArgumentException if `a` is not square.
 * @throws Exception if LAPACKE_dgees reports a non-zero status (info).
 */
public func schur(a: Matrix<Float64>): (Matrix<Float64>, Matrix<Float64>) {
    let m = a.getRows()
    let n = a.getCols()
    // dgees is only defined for square matrices; a mismatch would hand LAPACK
    // an inconsistent (n, lda) description of the buffer.
    if (m != n) {
        throw IllegalArgumentException("schur: expected a square matrix, got ${m}x${n}")
    }

    // LAPACK takes single-character flags (compared case-insensitively).
    let jobvs: UInt8 = UInt8(UInt32(r'v'))  // compute Schur vectors Z
    let sort: UInt8 = UInt8(UInt32(r'n'))   // no ordering, so `select` is never called
    let select: Bool = false

    // LAPACK requires leading dimensions >= max(1, n), even for an empty matrix.
    let lda: Int64 = if (n > 0) { n } else { 1 }
    let ldvs: Int64 = lda

    var sdim = vector<Int64>([0])
    var wr = zeros<Float64>(n)
    var wi = zeros<Float64>(n)
    var vs = zeros<Float64>(n, n)

    // dgees overwrites its input with T, so factorize a copy.
    let t = a.copy()
    let info = unsafe {
        LAPACKE_dgees(LAPACK_COL_MAJOR, jobvs, sort, select, IntNative(n),
            t.ptr, IntNative(lda), sdim.ptr, wr.ptr, wi.ptr, vs.ptr, IntNative(ldvs))
    }
    // A non-zero info means T and Z are not valid; never return them silently.
    if (info != 0) {
        throw Exception("schur: LAPACKE_dgees failed with info = ${info}")
    }
    return (t, vs)
}

/**
 * Low-level check of the LAPACKE_dgees binding on a fixed 3x3 matrix.
 *
 * Calls dgees directly (overwriting `a` in place with the Schur form T) and
 * compares T and the Schur vectors Z against reference values to 1e-6.
 * Throws if LAPACK reports a non-zero status, so a failing call is reported
 * by its info code rather than as a numeric mismatch.
 */
func testDgees() {
    // Single-character LAPACK flags: compute Schur vectors, no eigenvalue ordering.
    let jobvs: UInt8 = UInt8(UInt32(r'v'))
    let sort: UInt8 = UInt8(UInt32(r'n'))
    let select: Bool = false  // unused by LAPACK because sort = 'n'
    var sdim = vector<Int64>([0])
    var a = matrix<Float64>(
        [[0.0, 2.0, 2.0],
         [0.0, 1.0, 2.0],
         [1.0, 0.0, 1.0]]
    )
    let m = a.getRows()
    let n = a.getCols()
    let lda: Int64 = m
    let ldvs: Int64 = n
    var wr = zeros<Float64>(n)
    var wi = zeros<Float64>(n)
    var vs = zeros<Float64>(ldvs, n)
    let info = unsafe {
        LAPACKE_dgees(LAPACK_COL_MAJOR, jobvs, sort, select, IntNative(n),
            a.ptr, IntNative(lda), sdim.ptr, wr.ptr, wi.ptr, vs.ptr, IntNative(ldvs))
    }
    if (info != 0) {
        throw Exception("LAPACKE_dgees failed with info = ${info}")
    }

    let expected_t = matrix<Float64>(
        [[2.658967, 1.424405, -1.929334],
         [0.000000, -0.329484, -0.490637],
         [0.000000, 1.311789, -0.329484]]
    )
    let expected_z = matrix<Float64>(
        [[0.727116, -0.601562, 0.330796],
         [0.528394, 0.798019, 0.289768],
         [0.438294, 0.035904, -0.898114]]
    )
    assertApproxEqual(a, expected_t, atol: 1e-6)
    assertApproxEqual(vs, expected_z, atol: 1e-6)
}

/**
 * Unit tests for the real Schur decomposition, covering both the high-level
 * schur() wrapper and the raw LAPACKE_dgees binding.
 */
@Test
public class TestSchur {
    /**
     * schur() on a fixed 3x3 matrix yields the reference Schur form T and
     * Schur vectors Z (absolute tolerance 1e-6).
     */
    @TestCase
    func testSchur(): Unit {
        let input = matrix<Float64>(
            [[0.0, 2.0, 2.0],
             [0.0, 1.0, 2.0],
             [1.0, 0.0, 1.0]]
        )
        let referenceT = matrix<Float64>(
            [[2.658967, 1.424405, -1.929334],
             [0.000000, -0.329484, -0.490637],
             [0.000000, 1.311789, -0.329484]]
        )
        let referenceZ = matrix<Float64>(
            [[0.727116, -0.601562, 0.330796],
             [0.528394, 0.798019, 0.289768],
             [0.438294, 0.035904, -0.898114]]
        )

        let (schurForm, schurVectors) = schur(input)

        assertApproxEqual(schurForm, referenceT, atol: 1e-6)
        assertApproxEqual(schurVectors, referenceZ, atol: 1e-6)
    }

    /** Runs the direct LAPACKE_dgees binding check. */
    @TestCase
    func testLapackDgees(): Unit {
        testDgees()
    }
}