// CBLAS Level 3 functions

package scientific.linear

import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*

/* Double-precision general matrix multiply (BLAS Level 3):
       C := alpha * op(A) * op(B) + beta * C
   where op(X) is X or its transpose, selected by transpose_a / transpose_b.
   The result overwrites the buffer behind `c`.

   The layout and transpose arguments are the integer enum codes of the C
   CBLAS interface, not characters. In the reference cblas.h:
   CblasRowMajor = 101, CblasColMajor = 102, CblasNoTrans = 111,
   CblasTrans = 112, CblasConjTrans = 113. The wrappers in this file pass
   102 and 111.

   NOTE(review): the reference (LP64) cblas.h declares the enum and size
   arguments as 32-bit `int`, but they are bound here as Int64. Confirm this
   matches the integer width / calling convention of the linked BLAS
   (e.g. an ILP64 build).
*/
foreign func cblas_dgemm(
    matrix_layout: Int64,   // CBLAS_LAYOUT code: row major or column major
    transpose_a: Int64,     // CBLAS_TRANSPOSE code for op(A) (enum value, not 'n'/'t')
    transpose_b: Int64,     // CBLAS_TRANSPOSE code for op(B) (enum value, not 'n'/'t')
    m: Int64,               // number of rows of op(A) and of C
    n: Int64,               // number of columns of op(B) and of C
    k: Int64,               // number of columns of op(A) = number of rows of op(B)
    alpha: Float64,         // scalar factor alpha
    a: CPointer<Unit>,      // matrix A (Float64 elements)
    lda: Int64,             // leading dimension of A (BLAS requires >= 1)
    b: CPointer<Unit>,      // matrix B (Float64 elements)
    ldb: Int64,             // leading dimension of B (BLAS requires >= 1)
    beta: Float64,          // scalar factor beta
    c: CPointer<Unit>,      // matrix C (Float64 elements), overwritten with the result
    ldc: Int64              // leading dimension of C (BLAS requires >= 1)
): Unit

/* Single-precision (Float32) variant of cblas_dgemm:
       C := alpha * op(A) * op(B) + beta * C
   The layout and transpose arguments are CBLAS integer enum codes (in the
   reference cblas.h, 102 = CblasColMajor and 111 = CblasNoTrans).

   NOTE(review): the enum and size arguments are `int` in the reference
   cblas.h but are bound as Int64 here. Confirm against the linked BLAS.
*/
foreign func cblas_sgemm(
    matrix_layout: Int64,   // CBLAS_LAYOUT code: row major or column major
    transpose_a: Int64,     // CBLAS_TRANSPOSE code for op(A) (enum value, not 'n'/'t')
    transpose_b: Int64,     // CBLAS_TRANSPOSE code for op(B) (enum value, not 'n'/'t')
    m: Int64,               // number of rows of op(A) and of C
    n: Int64,               // number of columns of op(B) and of C
    k: Int64,               // number of columns of op(A) = number of rows of op(B)
    alpha: Float32,         // scalar factor alpha
    a: CPointer<Unit>,      // matrix A (Float32 elements)
    lda: Int64,             // leading dimension of A (BLAS requires >= 1)
    b: CPointer<Unit>,      // matrix B (Float32 elements)
    ldb: Int64,             // leading dimension of B (BLAS requires >= 1)
    beta: Float32,          // scalar factor beta
    c: CPointer<Unit>,      // matrix C (Float32 elements), overwritten with the result
    ldc: Int64              // leading dimension of C (BLAS requires >= 1)
): Unit

/* Double-precision complex (Complex64) variant of cblas_dgemm:
       C := alpha * op(A) * op(B) + beta * C
   Unlike the real variants, alpha and beta are passed by pointer, each
   pointing at a single complex scalar. The layout and transpose arguments
   are CBLAS integer enum codes (in the reference cblas.h, 102 =
   CblasColMajor, 111 = CblasNoTrans, 113 = CblasConjTrans).

   NOTE(review): the enum and size arguments are `int` in the reference
   cblas.h but are bound as Int64 here. Confirm against the linked BLAS.
*/
foreign func cblas_zgemm(
    matrix_layout: Int64,   // CBLAS_LAYOUT code: row major or column major
    transpose_a: Int64,     // CBLAS_TRANSPOSE code for op(A) (enum value, not 'n'/'t')
    transpose_b: Int64,     // CBLAS_TRANSPOSE code for op(B) (enum value, not 'n'/'t')
    m: Int64,               // number of rows of op(A) and of C
    n: Int64,               // number of columns of op(B) and of C
    k: Int64,               // number of columns of op(A) = number of rows of op(B)
    alpha: CPointer<Unit>,  // pointer to the complex scalar alpha
    a: CPointer<Unit>,      // matrix A (Complex64 elements)
    lda: Int64,             // leading dimension of A (BLAS requires >= 1)
    b: CPointer<Unit>,      // matrix B (Complex64 elements)
    ldb: Int64,             // leading dimension of B (BLAS requires >= 1)
    beta: CPointer<Unit>,   // pointer to the complex scalar beta
    c: CPointer<Unit>,      // matrix C (Complex64 elements), overwritten with the result
    ldc: Int64              // leading dimension of C (BLAS requires >= 1)
): Unit

/* Single-precision complex (Complex32) variant of cblas_dgemm:
       C := alpha * op(A) * op(B) + beta * C
   Unlike the real variants, alpha and beta are passed by pointer, each
   pointing at a single complex scalar. The layout and transpose arguments
   are CBLAS integer enum codes (in the reference cblas.h, 102 =
   CblasColMajor, 111 = CblasNoTrans, 113 = CblasConjTrans).

   NOTE(review): the enum and size arguments are `int` in the reference
   cblas.h but are bound as Int64 here. Confirm against the linked BLAS.
*/
foreign func cblas_cgemm(
    matrix_layout: Int64,   // CBLAS_LAYOUT code: row major or column major
    transpose_a: Int64,     // CBLAS_TRANSPOSE code for op(A) (enum value, not 'n'/'t')
    transpose_b: Int64,     // CBLAS_TRANSPOSE code for op(B) (enum value, not 'n'/'t')
    m: Int64,               // number of rows of op(A) and of C
    n: Int64,               // number of columns of op(B) and of C
    k: Int64,               // number of columns of op(A) = number of rows of op(B)
    alpha: CPointer<Unit>,  // pointer to the complex scalar alpha
    a: CPointer<Unit>,      // matrix A (Complex32 elements)
    lda: Int64,             // leading dimension of A (BLAS requires >= 1)
    b: CPointer<Unit>,      // matrix B (Complex32 elements)
    ldb: Int64,             // leading dimension of B (BLAS requires >= 1)
    beta: CPointer<Unit>,   // pointer to the complex scalar beta
    c: CPointer<Unit>,      // matrix C (Complex32 elements), overwritten with the result
    ldc: Int64              // leading dimension of C (BLAS requires >= 1)
): Unit

/* Returns the matrix product a * b of two Float64 matrices, computed by
   cblas_dgemm with alpha = 1 and beta = 0.

   The operand buffers are handed to BLAS as column-major (code 102) and
   untransposed (code 111).
   NOTE(review): this presumes Matrix<Float64>.ptr points at contiguous
   column-major storage. Confirm against the Matrix type.

   Parameters:
     a - left operand, m x k
     b - right operand, k x n
   Returns: a newly allocated m x n matrix.
   Throws: IllegalArgumentException when a.getCols() != b.getRows().
*/
public func cj_dgemm(a: Matrix<Float64>, b: Matrix<Float64>): Matrix<Float64> {
    if (a.getCols() != b.getRows()) {
        throw IllegalArgumentException("Matrix multiply: dimension mismatch")
    }

    let CblasColMajor: Int64 = 102
    let CblasNoTrans: Int64 = 111
    let m = a.getRows()
    let n = b.getCols()
    let k = a.getCols()
    // BLAS validates lda/ldb/ldc >= max(1, rows) before anything else, so a
    // leading dimension of 0 (empty operand) is rejected as an illegal
    // parameter. Clamp to 1 so empty products stay legal.
    let lda = if (m > 1) { m } else { 1 }
    let ldb = if (k > 1) { k } else { 1 }
    let ldc = lda
    let c = empty<Float64>(m, n)
    // Nothing to compute for an empty result; skip the FFI call entirely.
    if (m == 0 || n == 0) {
        return c
    }

    let alpha: Float64 = 1.0
    let beta: Float64 = 0.0
    unsafe {
        cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
                    a.ptr, lda, b.ptr, ldb, beta, c.ptr, ldc)
    }
    return c
}

/* Returns the matrix product a * b of two Float32 matrices, computed by
   cblas_sgemm with alpha = 1 and beta = 0.

   The operand buffers are handed to BLAS as column-major (code 102) and
   untransposed (code 111).
   NOTE(review): this presumes Matrix<Float32>.ptr points at contiguous
   column-major storage. Confirm against the Matrix type.

   Parameters:
     a - left operand, m x k
     b - right operand, k x n
   Returns: a newly allocated m x n matrix.
   Throws: IllegalArgumentException when a.getCols() != b.getRows().
*/
public func cj_sgemm(a: Matrix<Float32>, b: Matrix<Float32>): Matrix<Float32> {
    if (a.getCols() != b.getRows()) {
        throw IllegalArgumentException("Matrix multiply: dimension mismatch")
    }

    let CblasColMajor: Int64 = 102
    let CblasNoTrans: Int64 = 111
    let m = a.getRows()
    let n = b.getCols()
    let k = a.getCols()
    // BLAS validates lda/ldb/ldc >= max(1, rows) before anything else, so a
    // leading dimension of 0 (empty operand) is rejected as an illegal
    // parameter. Clamp to 1 so empty products stay legal.
    let lda = if (m > 1) { m } else { 1 }
    let ldb = if (k > 1) { k } else { 1 }
    let ldc = lda
    let c = empty<Float32>(m, n)
    // Nothing to compute for an empty result; skip the FFI call entirely.
    if (m == 0 || n == 0) {
        return c
    }

    let alpha = Float32(1.0)
    let beta = Float32(0.0)
    unsafe {
        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
                    a.ptr, lda, b.ptr, ldb, beta, c.ptr, ldc)
    }
    return c
}

/* Returns the matrix product a * b of two Complex64 matrices, computed by
   cblas_zgemm with alpha = 1 + 0i and beta = 0 + 0i.

   zgemm takes its scalars by pointer, so alpha and beta are staged in
   16-byte malloc'd buffers. The buffers are freed on every exit path,
   including when an exception is thrown.
   NOTE(review): this presumes Matrix<Complex64>.ptr points at contiguous
   column-major storage. Confirm against the Matrix type.

   Parameters:
     a - left operand, m x k
     b - right operand, k x n
   Returns: a newly allocated m x n matrix.
   Throws: IllegalArgumentException when a.getCols() != b.getRows();
           IllegalStateException when the scalar buffers cannot be allocated.
*/
public func cj_zgemm(a: Matrix<Complex64>, b: Matrix<Complex64>): Matrix<Complex64> {
    if (a.getCols() != b.getRows()) {
        throw IllegalArgumentException("Matrix multiply: dimension mismatch")
    }

    let CblasColMajor: Int64 = 102
    let CblasNoTrans: Int64 = 111
    let m = a.getRows()
    let n = b.getCols()
    let k = a.getCols()
    // BLAS validates lda/ldb/ldc >= max(1, rows) before anything else, so a
    // leading dimension of 0 (empty operand) is rejected as an illegal
    // parameter. Clamp to 1 so empty products stay legal.
    let lda = if (m > 1) { m } else { 1 }
    let ldb = if (k > 1) { k } else { 1 }
    let ldc = lda
    let c = empty<Complex64>(m, n)
    // Nothing to compute for an empty result; skip the allocation and FFI call.
    if (m == 0 || n == 0) {
        return c
    }

    let alpha: CPointer<Unit> = unsafe { malloc(UIntNative(16)) }
    let beta: CPointer<Unit> = unsafe { malloc(UIntNative(16)) }
    try {
        // Never write through a null pointer if malloc failed.
        if (unsafe { alpha.isNull() || beta.isNull() }) {
            throw IllegalStateException("Matrix multiply: failed to allocate scalar buffers")
        }
        Complex64.write(alpha, Complex64(1.0, 0.0), 0)
        Complex64.write(beta, Complex64(0.0, 0.0), 0)
        unsafe {
            cblas_zgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
                        a.ptr, lda, b.ptr, ldb, beta, c.ptr, ldc)
        }
    } finally {
        // free(NULL) is a no-op in C, so both buffers can be released unconditionally.
        unsafe {
            free(alpha)
            free(beta)
        }
    }
    return c
}

/* Returns the matrix product a * b of two Complex32 matrices, computed by
   cblas_cgemm with alpha = 1 + 0i and beta = 0 + 0i.

   cgemm takes its scalars by pointer, so alpha and beta are staged in
   8-byte malloc'd buffers. The buffers are freed on every exit path,
   including when an exception is thrown.
   NOTE(review): this presumes Matrix<Complex32>.ptr points at contiguous
   column-major storage. Confirm against the Matrix type.

   Parameters:
     a - left operand, m x k
     b - right operand, k x n
   Returns: a newly allocated m x n matrix.
   Throws: IllegalArgumentException when a.getCols() != b.getRows();
           IllegalStateException when the scalar buffers cannot be allocated.
*/
public func cj_cgemm(a: Matrix<Complex32>, b: Matrix<Complex32>): Matrix<Complex32> {
    if (a.getCols() != b.getRows()) {
        throw IllegalArgumentException("Matrix multiply: dimension mismatch")
    }

    let CblasColMajor: Int64 = 102
    let CblasNoTrans: Int64 = 111
    let m = a.getRows()
    let n = b.getCols()
    let k = a.getCols()
    // BLAS validates lda/ldb/ldc >= max(1, rows) before anything else, so a
    // leading dimension of 0 (empty operand) is rejected as an illegal
    // parameter. Clamp to 1 so empty products stay legal.
    let lda = if (m > 1) { m } else { 1 }
    let ldb = if (k > 1) { k } else { 1 }
    let ldc = lda
    let c = empty<Complex32>(m, n)
    // Nothing to compute for an empty result; skip the allocation and FFI call.
    if (m == 0 || n == 0) {
        return c
    }

    let alpha: CPointer<Unit> = unsafe { malloc(UIntNative(8)) }
    let beta: CPointer<Unit> = unsafe { malloc(UIntNative(8)) }
    try {
        // Never write through a null pointer if malloc failed.
        if (unsafe { alpha.isNull() || beta.isNull() }) {
            throw IllegalStateException("Matrix multiply: failed to allocate scalar buffers")
        }
        Complex32.write(alpha, Complex32(1.0, 0.0), 0)
        Complex32.write(beta, Complex32(0.0, 0.0), 0)
        unsafe {
            cblas_cgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
                        a.ptr, lda, b.ptr, ldb, beta, c.ptr, ldc)
        }
    } finally {
        // free(NULL) is a no-op in C, so both buffers can be released unconditionally.
        unsafe {
            free(alpha)
            free(beta)
        }
    }
    return c
}

/* Exercises the raw cblas_dgemm binding directly, bypassing the cj_dgemm
   wrapper. It multiplies a fixed 3x4 matrix by a fixed 4x5 matrix in
   column-major, untransposed mode and checks the product against a
   hand-computed 3x5 result (absolute tolerance 1e-6).
*/
func testOriginalDgemm() {
    // Operands: lhs is 3x4, rhs is 4x5, so the product is 3x5.
    let lhs = matrix<Float64>(
        [[1.0, 2.0, 4.0, 6.0],
         [4.0, 6.0, 8.0, 10.0],
         [7.0, 9.0, 11.0, 13.0]]
    )
    let rhs = matrix<Float64>(
        [[1.0, 2.0, 3.0, 10.0, 11.0],
         [4.0, 5.0, 6.0, 12.0, 13.0],
         [7.0, 8.0, 9.0, 14.0, 15.0],
         [10.0, 11.0, 12.0, 16.0, 17.0]]
    )
    let product = empty<Float64>(3, 5)

    // CBLAS enum codes: column-major layout, no transposition of either operand.
    let colMajor: Int64 = 102
    let noTrans: Int64 = 111
    let rows: Int64 = 3     // rows of lhs and of the product
    let cols: Int64 = 5     // columns of rhs and of the product
    let inner: Int64 = 4    // shared dimension
    let scale: Float64 = 1.0
    let accumulate: Float64 = 0.0

    // Column-major leading dimensions equal each operand's row count.
    unsafe {
        cblas_dgemm(colMajor, noTrans, noTrans, rows, cols, inner, scale,
                    lhs.ptr, rows, rhs.ptr, inner, accumulate, product.ptr, rows)
    }

    let reference = matrix<Float64>(
        [[ 97.0, 110.0, 123.0, 186.0, 199.0],
         [184.0, 212.0, 240.0, 384.0, 412.0],
         [250.0, 290.0, 330.0, 540.0, 580.0]]
    )
    assertApproxEqual(product, reference, atol: 1e-6)
}

/* Unit tests for the GEMM bindings and their cj_* wrappers. Each case
   multiplies fixed operands and compares the result against a
   hand-computed product.
*/
@Test
public class TestBlasL3 {
    // Raw FFI path: cblas_dgemm called directly (see testOriginalDgemm).
    @TestCase
    func testBlasDgemm(): Unit {
        testOriginalDgemm()
    }

    // Float64 wrapper: (3x4) * (4x5), tolerance 1e-6.
    @TestCase
    func test_dgemm(): Unit {
        let lhs = matrix<Float64>(
            [[1.0, 2.0, 4.0, 6.0],
             [4.0, 6.0, 8.0, 10.0],
             [7.0, 9.0, 11.0, 13.0]]
        )
        let rhs = matrix<Float64>(
            [[1.0, 2.0, 3.0, 10.0, 11.0],
             [4.0, 5.0, 6.0, 12.0, 13.0],
             [7.0, 8.0, 9.0, 14.0, 15.0],
             [10.0, 11.0, 12.0, 16.0, 17.0]]
        )
        let actual = cj_dgemm(lhs, rhs)
        let reference = matrix<Float64>(
            [[ 97.0, 110.0, 123.0, 186.0, 199.0],
             [184.0, 212.0, 240.0, 384.0, 412.0],
             [250.0, 290.0, 330.0, 540.0, 580.0]]
        )
        @Assert(approxEqual(actual, reference, atol: 1e-6))
    }

    // Float32 wrapper: same operands as test_dgemm, looser tolerance 1e-3.
    @TestCase
    func test_sgemm(): Unit {
        let lhs = matrix<Float32>(
            [[1.0, 2.0, 4.0, 6.0],
             [4.0, 6.0, 8.0, 10.0],
             [7.0, 9.0, 11.0, 13.0]]
        )
        let rhs = matrix<Float32>(
            [[1.0, 2.0, 3.0, 10.0, 11.0],
             [4.0, 5.0, 6.0, 12.0, 13.0],
             [7.0, 8.0, 9.0, 14.0, 15.0],
             [10.0, 11.0, 12.0, 16.0, 17.0]]
        )
        let actual = cj_sgemm(lhs, rhs)
        let reference = matrix<Float32>(
            [[ 97.0, 110.0, 123.0, 186.0, 199.0],
             [184.0, 212.0, 240.0, 384.0, 412.0],
             [250.0, 290.0, 330.0, 540.0, 580.0]]
        )
        @Assert(approxEqual(actual, reference, atol: 1e-3))
    }

    // Complex64 wrapper: 2x2 * 2x2 with mixed real/imaginary entries.
    @TestCase
    func test_zgemm(): Unit {
        let lhs = matrix<Complex64>(
            [[Complex64(1.0, 1.0), Complex64(0.0, 3.0)],
             [Complex64(0.0, 3.0), Complex64(2.0, 0.0)]]
        )
        let rhs = matrix<Complex64>(
            [[Complex64(1.0, 0.0), Complex64(2.0, 3.0)],
             [Complex64(2.0, -3.0), Complex64(4.0, 0.0)]]
        )
        let actual = cj_zgemm(lhs, rhs)
        let reference = matrix<Complex64>(
            [[Complex64(10.0, 7.0), Complex64(-1.0, 17.0)],
             [Complex64(4.0, -3.0), Complex64(-1.0, 6.0)]]
        )
        @Assert(approxEqualC(actual, reference))
    }

    // Complex32 wrapper: same operands as test_zgemm.
    @TestCase
    func test_cgemm(): Unit {
        let lhs = matrix<Complex32>(
            [[Complex32(1.0, 1.0), Complex32(0.0, 3.0)],
             [Complex32(0.0, 3.0), Complex32(2.0, 0.0)]]
        )
        let rhs = matrix<Complex32>(
            [[Complex32(1.0, 0.0), Complex32(2.0, 3.0)],
             [Complex32(2.0, -3.0), Complex32(4.0, 0.0)]]
        )
        let actual = cj_cgemm(lhs, rhs)
        let reference = matrix<Complex32>(
            [[Complex32(10.0, 7.0), Complex32(-1.0, 17.0)],
             [Complex32(4.0, -3.0), Complex32(-1.0, 6.0)]]
        )
        @Assert(approxEqualC(actual, reference))
    }
}