package scientific.linear

import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*

/**
 * Solves the orthogonal Procrustes problem for `a` and `b`.
 *
 * Computes the SVD of a^T * b as (U, W, V^T) and returns R = U * V^T
 * together with sum(W). Because `a` and `b` must share a shape (m x n),
 * a^T * b is n x n, so R is n x n as well.
 *
 * NOTE(review): assumes `svd` returns (U, singular values, V^T) in that
 * order, so `scale` is the sum of the singular values — confirm against svd.
 *
 * @param a  the source matrix (m x n).
 * @param b  the target matrix; must have exactly the same shape as `a`.
 * @return   (R, scale) where R = U * V^T and scale = sum(W).
 * @throws IllegalArgumentException if `a` and `b` differ in shape.
 */
public func orthogonalProcrustes(a: Matrix<Float64>, b: Matrix<Float64>): (Matrix<Float64>, Float64) {
    // With mismatched shapes a^T * b is either not computable or non-square,
    // and U * V^T would no longer be an n x n rotation; reject up front.
    if (a.getRows() != b.getRows() || a.getCols() != b.getCols()) {
        throw IllegalArgumentException(
            "orthogonalProcrustes: shape mismatch, a is ${a.getRows()}x${a.getCols()} but b is ${b.getRows()}x${b.getCols()}")
    }
    let (u, w, vt) = svd(a.transpose() * b)
    let R = u * vt
    let scale = sum(w)
    return (R, scale)
}

/**
 * Computes the lower and upper bandwidth of `a`.
 *
 * The lower bandwidth is the largest (row - col) over nonzero entries below
 * the main diagonal; the upper bandwidth is the largest (col - row) over
 * nonzero entries above it. Any entry not equal to 0.0 (including NaN)
 * counts as nonzero. A matrix with no off-diagonal nonzeros, including an
 * empty one, yields (0, 0).
 *
 * @param a  the matrix to inspect; it need not be square.
 * @return   (lower, upper) bandwidth.
 */
public func bandwidth(a: Matrix<Float64>): (Int64, Int64) {
    let nRows = a.getRows()
    let nCols = a.getCols()
    var lower: Int64 = 0
    var upper: Int64 = 0
    for (r in 0..nRows) {
        for (c in 0..nCols) {
            if (a[r, c] == 0.0) {
                continue
            }
            // Positive offset: below the diagonal; negative: above it.
            let offset = r - c
            if (offset > lower) {
                lower = offset
            } else if (-offset > upper) {
                upper = -offset
            }
        }
    }
    return (lower, upper)
}

/**
 * Reports whether `a` is symmetric: a[i, j] matches a[j, i] for every
 * off-diagonal pair, up to an absolute tolerance.
 *
 * A pair matches when the two values are exactly equal or when their
 * absolute difference is at most `atol`. NaN entries never match, so a
 * matrix containing an off-diagonal NaN is reported as not symmetric.
 *
 * @param a     the matrix to test; non-square matrices are never symmetric.
 * @param atol  absolute tolerance on |a[i, j] - a[j, i]| (default 0.0, exact).
 * @return      true if every off-diagonal pair matches.
 */
public func issymmetric(a: Matrix<Float64>, atol!: Float64 = 0.0): Bool {
    // Written as "equal or within tolerance" rather than "difference exceeds
    // tolerance": any comparison with NaN is false, so the latter form would
    // silently accept NaN. The exact-equality test keeps equal infinities
    // (whose difference is NaN) matching.
    func near(x: Float64, y: Float64): Bool {
        x == y || abs(x - y) <= atol
    }

    if (!a.isSquare()) {
        return false
    }
    let n = a.getRows()
    for (i in 0..n) {
        for (j in i+1..n) {
            if (!near(a[i, j], a[j, i])) {
                return false
            }
        }
    }
    return true
}

/**
 * Reports whether `a` is Hermitian: every diagonal entry is real, and
 * a[i, j] is the complex conjugate of a[j, i] for every off-diagonal pair,
 * each up to an absolute tolerance.
 *
 * Real and imaginary parts are compared separately. A component matches
 * when the values are exactly equal or differ by at most `atol`. NaN
 * components never match, so they make the matrix non-Hermitian.
 *
 * @param a     the matrix to test; non-square matrices are never Hermitian.
 * @param atol  absolute tolerance applied to each real/imaginary comparison
 *              (default 0.0, exact).
 * @return      true if all diagonal and off-diagonal checks pass.
 */
public func ishermitian(a: Matrix<Complex64>, atol!: Float64 = 0.0): Bool {
    // "Equal or within tolerance" (rather than "difference exceeds tolerance")
    // so that NaN, for which every comparison is false, is rejected instead of
    // silently accepted; exact equality keeps matching infinities matching.
    func near(x: Float64, y: Float64): Bool {
        x == y || abs(x - y) <= atol
    }

    if (!a.isSquare()) {
        return false
    }
    let n = a.getRows()
    for (i in 0..n) {
        // Diagonal entries of a Hermitian matrix must be real.
        if (!near(a[i, i].imag, 0.0)) {
            return false
        }
        for (j in i+1..n) {
            let aij = a[i, j]
            let aji = a[j, i]
            // Conjugate symmetry: equal real parts, negated imaginary parts.
            if (!near(aij.real, aji.real) || !near(aij.imag, -aji.imag)) {
                return false
            }
        }
    }
    return true
}

@Test
public class TestLinearBasic {
    @TestCase
    func testProcrustes(): Unit {
        // b is a with its column order reversed, so R should be the
        // anti-diagonal permutation.
        let a = matrix<Float64>(
            [[ 2.0, 0.0, 1.0],
             [-2.0, 0.0, 0.0]]
        )
        let b = matrix<Float64>(
            [[1.0, 0.0,  2.0],
             [0.0, 0.0, -2.0]]
        )
        let (R, sca) = orthogonalProcrustes(a, b)
        let expectedR = matrix<Float64>(
            [[0.0, 0.0, 1.0],
             [0.0, 1.0, 0.0],
             [1.0, 0.0, 0.0]]
        )
        let expectedSca: Float64 = 9.0
        @Assert(approxEqual(R, expectedR, atol: 1e-6))
        // The scale comes out of a numerical SVD; exact float equality is
        // fragile, so compare within a tolerance.
        @Assert(abs(sca - expectedSca) < 1e-6)
    }

    @TestCase
    func testBandwidth(): Unit {
        let a = matrix<Float64>(
            [[3.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 4.0, 0.0, 0.0, 0.0],
             [0.0, 0.0, 5.0, 1.0, 0.0],
             [8.0, 0.0, 0.0, 6.0, 2.0],
             [0.0, 9.0, 0.0, 0.0, 7.0]]
        )
        let (lb, ub) = bandwidth(a)
        @Assert(lb == 3)
        @Assert(ub == 1)

        // A purely diagonal matrix has zero bandwidth on both sides.
        let d = matrix<Float64>(
            [[1.0, 0.0],
             [0.0, 2.0]]
        )
        let (dl, du) = bandwidth(d)
        @Assert(dl == 0)
        @Assert(du == 0)
    }

    @TestCase
    func testSymmetric(): Unit {
        let a = matrix<Float64>(
            [[0.0, 0.0, 1.0],
             [0.0, 1.0, 0.0],
             [1.0, 0.0, 0.0]]
        )
        @Assert(issymmetric(a))

        let asym = matrix<Float64>(
            [[1.0, 2.0],
             [3.0, 1.0]]
        )
        @Assert(!issymmetric(asym))

        // Off by ~1e-9: rejected exactly, accepted with a tolerance.
        let nearSym = matrix<Float64>(
            [[1.0, 2.0],
             [2.000000001, 1.0]]
        )
        @Assert(!issymmetric(nearSym))
        @Assert(issymmetric(nearSym, atol: 1e-6))

        let rect = matrix<Float64>(
            [[1.0, 0.0, 0.0],
             [0.0, 1.0, 0.0]]
        )
        @Assert(!issymmetric(rect))
    }

    @TestCase
    func testHermitian(): Unit {
        let A = matrix<Complex64>(
            [[Complex64(1.0, 0.0), Complex64(2.0, 3.0)],
             [Complex64(2.0, -3.0), Complex64(4.0, 0.0)]]
        )
        @Assert(ishermitian(A))

        // Fails on the non-real diagonal entry.
        let B = matrix<Complex64>(
            [[Complex64(1.0, 1.0), Complex64(0.0, 3.0)],
             [Complex64(0.0, 3.0), Complex64(2.0, 0.0)]]
        )
        @Assert(!ishermitian(B))

        // Real diagonal, but the off-diagonal pair is equal rather than conjugate.
        let C = matrix<Complex64>(
            [[Complex64(1.0, 0.0), Complex64(2.0, 3.0)],
             [Complex64(2.0, 3.0), Complex64(4.0, 0.0)]]
        )
        @Assert(!ishermitian(C))
    }
}