package scientific.fft

import std.unittest.*
import std.unittest.testmacro.*
import std.math.*

import scientific.numbers.Complex64
import scientific.numbers.sqrt
import scientific.linear.empty
import scientific.linear.vector
import scientific.linear.Vector
import scientific.linear.logspace
import scientific.linear.approxEqual
import scientific.linear.approxEqualC
import scientific.linear.Matrix
import scientific.linear.matrix

/*
 * Complex FFT bindings. The names and signatures look like the pocketfft C API
 * (assumed to be the linked native library -- confirm against the build).
 * `data` is the in-place buffer; callers in this file pass the storage of a
 * Vector<Complex64> (presumably interleaved real/imaginary Float64 pairs -- TODO confirm).
 * `fct` is the scale factor the wrappers use for normalization. The Int64 result is a
 * native status code (presumably 0 on success -- confirm with the native library).
 * The UIntNative result of cfft_length is the plan's transform length.
 */
foreign func make_cfft_plan(length: UIntNative): CPointer<Unit>
foreign func destroy_cfft_plan(plan: CPointer<Unit>): Unit
foreign func cfft_backward(plan: CPointer<Unit>, data: CPointer<Unit>, fct: Float64): Int64
foreign func cfft_forward(plan: CPointer<Unit>, data: CPointer<Unit>, fct: Float64): Int64
foreign func cfft_length(plan: CPointer<Unit>): UIntNative

/*
 * Real FFT bindings (data is a buffer of Float64). The wrappers in this file treat the
 * forward result as the packed sequence r0, r1, i1, r2, i2, ... (for even lengths, the
 * last slot holds the real Nyquist term), and irfft builds the same layout before
 * calling rfft_backward. Status codes and `fct` are as for the complex bindings above.
 */
foreign func make_rfft_plan(length: UIntNative): CPointer<Unit>
foreign func destroy_rfft_plan(plan: CPointer<Unit>): Unit
foreign func rfft_backward(plan: CPointer<Unit>, data: CPointer<Unit>, fct: Float64): Int64
foreign func rfft_forward(plan: CPointer<Unit>, data: CPointer<Unit>, fct: Float64): Int64
foreign func rfft_length(plan: CPointer<Unit>): UIntNative


/**
 * Owns a native complex FFT plan: created with make_cfft_plan and released with
 * destroy_cfft_plan when the object is finalized.
 *
 * A plan is bound to a single transform length. Both transform methods run in place
 * on the vector's buffer and scale the result by `fct`.
 */
public class CFFTPlan {
    // Native plan handle; remains null if construction failed.
    var ptr: CPointer<Unit> = CPointer<Unit>()

    /**
     * Creates a plan for transforms of `length` complex values.
     *
     * @throws IllegalArgumentException if `length` is not positive (a negative value
     *         would also overflow the UIntNative conversion).
     * @throws Exception if the native library returns a null plan; using such a plan
     *         would dereference null inside native code.
     */
    public init(length: Int64) {
        if (length <= 0) {
            throw IllegalArgumentException("CFFTPlan: length must be positive")
        }
        this.ptr = unsafe { make_cfft_plan(UIntNative(length)) }
        if (this.ptr.isNull()) {
            throw Exception("CFFTPlan: native plan creation failed")
        }
    }

    /** Returns the transform length reported by the native plan. */
    public func getLength(): UIntNative {
        return unsafe { cfft_length(this.ptr) }
    }

    /**
     * Forward transform of `data`, in place, with the result scaled by `fct`.
     *
     * @return the native status code (always 0 here, because non-zero statuses throw).
     * @throws IllegalArgumentException if `data` does not have the plan's length.
     * @throws Exception if the native transform reports a non-zero status.
     */
    public func fftForward(data: Vector<Complex64>, fct: Float64): Int64 {
        requireLength(data, "fftForward")
        let status = unsafe { cfft_forward(this.ptr, data.ptr, fct) }
        return checkStatus(status, "fftForward")
    }

    /**
     * Backward transform of `data`, in place, with the result scaled by `fct`.
     *
     * @return the native status code (always 0 here, because non-zero statuses throw).
     * @throws IllegalArgumentException if `data` does not have the plan's length.
     * @throws Exception if the native transform reports a non-zero status.
     */
    public func fftBackward(data: Vector<Complex64>, fct: Float64): Int64 {
        requireLength(data, "fftBackward")
        let status = unsafe { cfft_backward(this.ptr, data.ptr, fct) }
        return checkStatus(status, "fftBackward")
    }

    // Rejects buffers whose size differs from the plan length; the native code would
    // otherwise read or write past the end of the buffer.
    private func requireLength(data: Vector<Complex64>, op: String): Unit {
        if (data.size() != Int64(this.getLength())) {
            throw IllegalArgumentException("${op}: unexpected length")
        }
    }

    // Turns a non-zero native status into an exception instead of silently
    // returning a half-transformed buffer.
    private func checkStatus(status: Int64, op: String): Int64 {
        if (status != 0) {
            throw Exception("${op}: native transform failed with status ${status}")
        }
        return status
    }

    ~init() {
        // The handle is null if init threw, and null must not reach the native destructor.
        if (!ptr.isNull()) {
            unsafe { destroy_cfft_plan(ptr) }
        }
    }
}

/**
 * Owns a native real FFT plan: created with make_rfft_plan and released with
 * destroy_rfft_plan when the object is finalized.
 *
 * A plan is bound to a single transform length. Both transform methods run in place
 * on a Float64 buffer in the packed layout that rfft/irfft read and write, and scale
 * the result by `fct`.
 */
public class RFFTPlan {
    // Native plan handle; remains null if construction failed.
    var ptr: CPointer<Unit> = CPointer<Unit>()

    /**
     * Creates a plan for transforms of `length` real values.
     *
     * @throws IllegalArgumentException if `length` is not positive (a negative value
     *         would also overflow the UIntNative conversion).
     * @throws Exception if the native library returns a null plan; using such a plan
     *         would dereference null inside native code.
     */
    public init(length: Int64) {
        if (length <= 0) {
            throw IllegalArgumentException("RFFTPlan: length must be positive")
        }
        this.ptr = unsafe { make_rfft_plan(UIntNative(length)) }
        if (this.ptr.isNull()) {
            throw Exception("RFFTPlan: native plan creation failed")
        }
    }

    /** Returns the transform length reported by the native plan. */
    public func getLength(): UIntNative {
        return unsafe { rfft_length(this.ptr) }
    }

    /**
     * Forward transform of `data`, in place, with the result scaled by `fct`.
     *
     * @return the native status code (always 0 here, because non-zero statuses throw).
     * @throws IllegalArgumentException if `data` does not have the plan's length.
     * @throws Exception if the native transform reports a non-zero status.
     */
    public func fftForward(data: Vector<Float64>, fct: Float64): Int64 {
        requireLength(data, "fftForward")
        let status = unsafe { rfft_forward(this.ptr, data.ptr, fct) }
        return checkStatus(status, "fftForward")
    }

    /**
     * Backward transform of `data`, in place, with the result scaled by `fct`.
     *
     * @return the native status code (always 0 here, because non-zero statuses throw).
     * @throws IllegalArgumentException if `data` does not have the plan's length.
     * @throws Exception if the native transform reports a non-zero status.
     */
    public func fftBackward(data: Vector<Float64>, fct: Float64): Int64 {
        requireLength(data, "fftBackward")
        let status = unsafe { rfft_backward(this.ptr, data.ptr, fct) }
        return checkStatus(status, "fftBackward")
    }

    // Rejects buffers whose size differs from the plan length; the native code would
    // otherwise read or write past the end of the buffer.
    private func requireLength(data: Vector<Float64>, op: String): Unit {
        if (data.size() != Int64(this.getLength())) {
            throw IllegalArgumentException("${op}: unexpected length")
        }
    }

    // Turns a non-zero native status into an exception instead of silently
    // returning a half-transformed buffer.
    private func checkStatus(status: Int64, op: String): Int64 {
        if (status != 0) {
            throw Exception("${op}: native transform failed with status ${status}")
        }
        return status
    }

    ~init() {
        // The handle is null if init threw, and null must not reach the native destructor.
        if (!ptr.isNull()) {
            unsafe { destroy_rfft_plan(ptr) }
        }
    }
}

/**
 * Computes the 1-D discrete Fourier transform of a complex vector.
 *
 * @param data input samples; not modified (the transform runs on a copy).
 * @param norm normalization mode: "backward" (no scaling, default),
 *             "forward" (scale by 1/n) or "ortho" (scale by 1/sqrt(n)).
 * @return a new vector of the same length holding the spectrum.
 * @throws IllegalArgumentException if `data` is empty or `norm` is unknown.
 */
public func cfft(data: Vector<Complex64>, norm!: String = "backward"): Vector<Complex64> {
    let n = data.size()
    if (n == 0) {
        // A length-0 transform is undefined (the 1/n scaling below would divide by zero),
        // and a zero-length plan must not be handed to the native library.
        throw IllegalArgumentException("cfft: empty input")
    }

    var fct: Float64
    if (norm == "backward") {
        fct = 1.0
    } else if (norm == "forward") {
        fct = 1.0 / Float64(n)
    } else if (norm == "ortho") {
        fct = 1.0 / sqrt(Float64(n))
    } else {
        throw IllegalArgumentException("cfft: unknown norm")
    }

    // The plan transforms in place, so work on a copy to keep the caller's data intact.
    let data_copy: Vector<Complex64> = data.copy()
    let plan: CFFTPlan = CFFTPlan(n)
    plan.fftForward(data_copy, fct)
    return data_copy
}

/**
 * Computes the 1-D inverse discrete Fourier transform of a complex vector.
 *
 * @param data spectrum; not modified (the transform runs on a copy).
 * @param norm normalization mode matching the forward call: "backward" (scale by 1/n, default),
 *             "forward" (no scaling) or "ortho" (scale by 1/sqrt(n)).
 * @return a new vector of the same length holding the reconstructed samples.
 * @throws IllegalArgumentException if `data` is empty or `norm` is unknown.
 */
public func icfft(data: Vector<Complex64>, norm!: String = "backward"): Vector<Complex64> {
    let n = data.size()
    if (n == 0) {
        // A length-0 transform is undefined (the 1/n scaling below would divide by zero),
        // and a zero-length plan must not be handed to the native library.
        throw IllegalArgumentException("icfft: empty input")
    }

    var fct: Float64
    if (norm == "backward") {
        fct = 1.0 / Float64(n)
    } else if (norm == "forward") {
        fct = 1.0
    } else if (norm == "ortho") {
        fct = 1.0 / sqrt(Float64(n))
    } else {
        throw IllegalArgumentException("icfft: unknown norm")
    }

    // The plan transforms in place, so work on a copy to keep the caller's data intact.
    let data_copy: Vector<Complex64> = data.copy()
    let plan: CFFTPlan = CFFTPlan(n)
    plan.fftBackward(data_copy, fct)
    return data_copy
}

/**
 * Computes the 1-D discrete Fourier transform of a real vector, returning only the
 * non-negative-frequency half of the spectrum.
 *
 * @param data real input samples (length n >= 1); not modified.
 * @param norm normalization mode: "backward" (no scaling, default),
 *             "forward" (scale by 1/n) or "ortho" (scale by 1/sqrt(n)).
 * @return a new vector of n / 2 + 1 complex bins. Bin 0 always has a zero imaginary
 *         part, and so does the last bin when n is even.
 * @throws IllegalArgumentException if `data` is empty or `norm` is unknown.
 */
public func rfft(data: Vector<Float64>, norm!: String = "backward"): Vector<Complex64> {
    let n = data.size()
    if (n == 0) {
        // A length-0 transform is undefined (the 1/n scaling below would divide by zero),
        // and a zero-length plan must not be handed to the native library.
        throw IllegalArgumentException("rfft: empty input")
    }

    var fct: Float64
    if (norm == "backward") {
        fct = 1.0
    } else if (norm == "forward") {
        fct = 1.0 / Float64(n)
    } else if (norm == "ortho") {
        fct = 1.0 / sqrt(Float64(n))
    } else {
        throw IllegalArgumentException("rfft: unknown norm")
    }

    let packed: Vector<Float64> = data.copy()
    let plan: RFFTPlan = RFFTPlan(n)
    plan.fftForward(packed, fct)

    // The native transform leaves `packed` as r0, r1, i1, r2, i2, ... Pairs (2k-1, 2k)
    // exist for k < (n + 1) / 2. For even n, the final slot is the real Nyquist term.
    let res = empty<Complex64>(n / 2 + 1)
    res[0] = Complex64(packed[0], 0.0)
    for (k in 1..((n + 1) / 2)) {
        res[k] = Complex64(packed[2 * k - 1], packed[2 * k])
    }
    if (n % 2 == 0) {
        res[n / 2] = Complex64(packed[n - 1], 0.0)
    }
    return res
}

/**
 * Inverse of rfft: rebuilds a real signal from its non-negative-frequency half-spectrum.
 *
 * @param data half-spectrum of m >= 1 coefficients (as produced by rfft); not modified.
 * @param norm normalization mode matching the forward call, applied with the output
 *             length L: "backward" (scale by 1/L, default), "forward" (no scaling)
 *             or "ortho" (scale by 1/sqrt(L)).
 * @param n    length L of the real output. When omitted, L is guessed:
 *               - 1 if m == 1;
 *               - otherwise 2*m-2 if the last coefficient's imaginary part is exactly 0.0
 *                 (rfft writes an exact 0.0 there for even-length signals);
 *               - else 2*m-1.
 *             The guess is wrong for odd-length signals whose last coefficient happens
 *             to be real (e.g. a constant signal), so pass `n` whenever the original
 *             length is known. Coefficients not needed for L are ignored, and missing
 *             ones are treated as zero.
 * @return a new real vector of length L.
 * @throws IllegalArgumentException if `data` is empty, `n` is not positive, or `norm` is unknown.
 */
public func irfft(data: Vector<Complex64>, norm!: String = "backward", n!: ?Int64 = None): Vector<Float64> {
    let m = data.size()
    if (m == 0) {
        throw IllegalArgumentException("irfft: empty input")
    }

    // Output length: explicit when given, otherwise inferred from the half-spectrum.
    let outLen: Int64 = match (n) {
        case Some(requested) => requested
        case None =>
            if (m == 1) {
                // A single bin can only come from a length-1 signal; 2*m-2 would give 0.
                1
            } else if (data[m - 1].imag == 0.0) {
                2 * m - 2
            } else {
                2 * m - 1
            }
    }
    if (outLen <= 0) {
        throw IllegalArgumentException("irfft: output length must be positive")
    }

    // Pack into the r0, r1, i1, r2, i2, ... layout that rfft unpacks. Every slot is
    // written explicitly, so bins missing from `data` become zeros.
    let packed = empty<Float64>(outLen)
    packed[0] = data[0].real
    for (k in 1..((outLen + 1) / 2)) {
        let present = k < m
        packed[2 * k - 1] = if (present) { data[k].real } else { 0.0 }
        packed[2 * k] = if (present) { data[k].imag } else { 0.0 }
    }
    if (outLen % 2 == 0) {
        // Even lengths store only the real part of the Nyquist bin, in the last slot.
        let nyquist = outLen / 2
        packed[outLen - 1] = if (nyquist < m) { data[nyquist].real } else { 0.0 }
    }

    var fct: Float64
    if (norm == "backward") {
        fct = 1.0 / Float64(outLen)
    } else if (norm == "forward") {
        fct = 1.0
    } else if (norm == "ortho") {
        fct = 1.0 / sqrt(Float64(outLen))
    } else {
        throw IllegalArgumentException("irfft: unknown norm")
    }

    let plan: RFFTPlan = RFFTPlan(outLen)
    plan.fftBackward(packed, fct)
    return packed
}

/* 2D FFT */
/**
 * Computes the 2-D discrete Fourier transform of a complex matrix by applying cfft
 * along every row and then along every column.
 *
 * @param data input matrix; not modified (the transform runs on a copy).
 * @param norm normalization mode forwarded to each 1-D transform: "backward" (default),
 *             "forward" or "ortho". Applying it per axis gives the overall 2-D scaling
 *             (1, 1/(rows*cols) or 1/sqrt(rows*cols)).
 * @return a new matrix of the same shape holding the 2-D spectrum.
 * @throws IllegalArgumentException from cfft if `norm` is unknown or an axis is empty.
 */
public func cfft2(data: Matrix<Complex64>, norm!: String = "backward"): Matrix<Complex64> {
    let nr = data.getRows()
    let nc = data.getCols()
    var data_copy: Matrix<Complex64> = data.copy()

    // Transform each row (nc samples each).
    for (i in 0..nr) {
        data_copy[i] = cfft(data_copy[i], norm: norm)
    }

    // Transform each column (nr samples each).
    for (j in 0..nc) {
        data_copy.setCol(j, cfft(data_copy.getCol(j), norm: norm))
    }
    return data_copy
}

/**
 * Computes the 2-D inverse discrete Fourier transform of a complex matrix by applying
 * icfft along every row and then along every column.
 *
 * @param data 2-D spectrum; not modified (the transform runs on a copy).
 * @param norm normalization mode forwarded to each 1-D inverse transform; it should
 *             match the mode used for the forward cfft2 call ("backward" by default).
 * @return a new matrix of the same shape holding the reconstructed values.
 * @throws IllegalArgumentException from icfft if `norm` is unknown or an axis is empty.
 */
public func icfft2(data: Matrix<Complex64>, norm!: String = "backward"): Matrix<Complex64> {
    let nr = data.getRows()
    let nc = data.getCols()
    var data_copy: Matrix<Complex64> = data.copy()

    // Inverse-transform each row (nc samples each).
    for (i in 0..nr) {
        data_copy[i] = icfft(data_copy[i], norm: norm)
    }

    // Inverse-transform each column (nr samples each).
    for (j in 0..nc) {
        data_copy.setCol(j, icfft(data_copy.getCol(j), norm: norm))
    }
    return data_copy
}

/**
 * Computes the 2-D discrete Fourier transform of a real matrix, keeping only the
 * non-negative frequencies along the last axis. It applies rfft to every row, then
 * cfft to every column of the resulting half-spectrum.
 *
 * @param data real input matrix (nr x nc); not modified.
 * @param norm normalization mode forwarded to each 1-D transform: "backward" (default),
 *             "forward" or "ortho".
 * @return a new nr x (nc / 2 + 1) complex matrix.
 * @throws IllegalArgumentException from rfft/cfft if `norm` is unknown or an axis is empty.
 */
public func rfft2(data: Matrix<Float64>, norm!: String = "backward"): Matrix<Complex64> {
    let nr = data.getRows()
    let half = data.getCols() / 2 + 1
    var res: Matrix<Complex64> = empty<Complex64>(nr, half)

    // Real transform along each row; rfft copies its input, so `data` is left untouched.
    for (i in 0..nr) {
        res[i] = rfft(data[i], norm: norm)
    }

    // Complex transform along each column of the half-spectrum.
    for (j in 0..half) {
        res.setCol(j, cfft(res.getCol(j), norm: norm))
    }
    return res
}

/**
 * Inverse of rfft2: reconstructs a real matrix from its 2-D half-spectrum.
 *
 * Undoes rfft2 in reverse order. First, an inverse complex FFT (icfft) runs along
 * every column. Then an inverse real FFT runs along every row, using a single plan
 * for all rows.
 *
 * @param data half-spectrum with nr >= 1 rows and nc >= 1 columns (as produced by
 *             rfft2); not modified.
 * @param norm normalization mode matching the forward call, applied per axis:
 *             "backward" (default), "forward" or "ortho".
 * @param n    number of columns W of the real output. When omitted, W is guessed:
 *               - 1 if nc == 1;
 *               - otherwise 2*nc-2 if, after the column pass, every imaginary part in
 *                 the last column is negligible (at most 1e-10 times that column's
 *                 largest component; it is only approximately zero because of rounding
 *                 in the column transforms);
 *               - else 2*nc-1.
 *             The guess is wrong for odd widths whose last column is real
 *             (e.g. constant rows), so pass `n` whenever the original width is known.
 * @return a new nr x W real matrix.
 * @throws IllegalArgumentException if `data` is empty, `n` is not positive, or `norm` is unknown.
 */
public func irfft2(data: Matrix<Complex64>, norm!: String = "backward", n!: ?Int64 = None): Matrix<Float64> {
    if (norm != "backward" && norm != "forward" && norm != "ortho") {
        throw IllegalArgumentException("irfft2: unknown norm")
    }
    let nr = data.getRows()
    let nc = data.getCols()
    if (nr == 0 || nc == 0) {
        throw IllegalArgumentException("irfft2: empty input")
    }

    // Pass 1: inverse complex FFT along each column (undoes rfft2's column pass).
    var spectrum: Matrix<Complex64> = data.copy()
    for (j in 0..nc) {
        spectrum.setCol(j, icfft(spectrum.getCol(j), norm: norm))
    }

    // Decide the output width once, so that every row is inverted with the same length.
    let width: Int64 = match (n) {
        case Some(requested) => requested
        case None => irfft2GuessWidth(spectrum)
    }
    if (width <= 0) {
        throw IllegalArgumentException("irfft2: output width must be positive")
    }

    // Row scaling uses the real output width, as in the 1-D inverse real transform.
    var fct: Float64 = 1.0
    if (norm == "backward") {
        fct = 1.0 / Float64(width)
    } else if (norm == "ortho") {
        fct = 1.0 / sqrt(Float64(width))
    }

    // Pass 2: inverse real FFT along each row.
    let plan: RFFTPlan = RFFTPlan(width)
    var res: Matrix<Float64> = empty<Float64>(nr, width)
    for (i in 0..nr) {
        let packed = irfft2PackRow(spectrum[i], width)
        plan.fftBackward(packed, fct)
        res[i] = packed
    }
    return res
}

/*
 * Guesses irfft2's output width from the column-inverted half-spectrum. rfft stores
 * a zero imaginary part in the last bin of even-length rows, so an even-width original
 * leaves a last column that is real up to rounding. An odd-width original generally
 * does not.
 */
private func irfft2GuessWidth(spectrum: Matrix<Complex64>): Int64 {
    let nc = spectrum.getCols()
    if (nc == 1) {
        return 1
    }
    let last = spectrum.getCol(nc - 1)
    var scale = 0.0
    for (i in 0..last.size()) {
        let re = abs(last[i].real)
        let im = abs(last[i].imag)
        if (re > scale) {
            scale = re
        }
        if (im > scale) {
            scale = im
        }
    }
    for (i in 0..last.size()) {
        if (abs(last[i].imag) > 1.0e-10 * scale) {
            return 2 * nc - 1
        }
    }
    return 2 * nc - 2
}

/*
 * Packs one half-spectrum row into the r0, r1, i1, r2, i2, ... layout consumed by
 * RFFTPlan.fftBackward, for an output of `width` samples. Every slot is written
 * explicitly, so bins beyond the row's length become zeros.
 */
private func irfft2PackRow(row: Vector<Complex64>, width: Int64): Vector<Float64> {
    let m = row.size()
    let packed = empty<Float64>(width)
    packed[0] = row[0].real
    for (k in 1..((width + 1) / 2)) {
        let present = k < m
        packed[2 * k - 1] = if (present) { row[k].real } else { 0.0 }
        packed[2 * k] = if (present) { row[k].imag } else { 0.0 }
    }
    if (width % 2 == 0) {
        // Even widths store only the real part of the Nyquist bin, in the last slot.
        let nyquist = width / 2
        packed[width - 1] = if (nyquist < m) { row[nyquist].real } else { 0.0 }
    }
    return packed
}


@Test
public class TestFFT {
    // 1-D complex transform: the samples e^{i*pi*k/2}, k = 0..3, put all their energy
    // (magnitude 4) in bin 1. The inverse must then recover the input.
    @TestCase
    func testFFT1(): Unit {
        let signal: Vector<Complex64> = vector<Complex64>([
            Complex64(1.0, 0.0), Complex64(0.0, 1.0), Complex64(-1.0, 0.0), Complex64(0.0, -1.0)
        ])
        let spectrum = cfft(signal)
        let recovered = icfft(spectrum)
        let wantSpectrum = vector<Complex64>([
            Complex64(0.0, 0.0), Complex64(4.0, 0.0), Complex64(0.0, 0.0), Complex64(0.0, 0.0)
        ])
        @Assert(approxEqualC(spectrum, wantSpectrum))
        @Assert(approxEqualC(signal, recovered))
    }

    // Real round trip on an even-length (8) signal.
    @TestCase
    func testFFT2(): Unit {
        let samples = vector<Float64>([1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 2.0])
        let halfSpectrum = rfft(samples)
        let recovered = irfft(halfSpectrum)
        @Assert(approxEqual(samples, recovered))
    }

    // Real round trip on an odd-length (7) signal.
    @TestCase
    func testFFT3(): Unit {
        let samples = vector<Float64>([1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0])
        let halfSpectrum = rfft(samples)
        let recovered = irfft(halfSpectrum)
        @Assert(approxEqual(samples, recovered))
    }

    // 2-D complex transform of a 2x2 matrix against a hand-computed spectrum,
    // plus the inverse round trip.
    @TestCase
    func testFFT4(): Unit {
        let grid = matrix<Complex64>([
            [Complex64(1.0, 2.0), Complex64(2.0, 1.0)],
            [Complex64(1.0, 1.0), Complex64(2.0, 2.0)]
        ])
        let spectrum = cfft2(grid)
        let wantSpectrum = matrix<Complex64>([
            [Complex64(6.0, 6.0), Complex64(-2.0, 0.0)],
            [Complex64(0.0, 0.0), Complex64(0.0, 2.0)]
        ])
        let recovered = icfft2(spectrum)
        @Assert(approxEqualC(spectrum, wantSpectrum))
        @Assert(approxEqualC(grid, recovered))
    }

    // 2-D real round trip with an odd number of columns (5).
    @TestCase
    func testFFT5(): Unit {
        let grid = matrix<Float64>(
            [[ 0.0,  1.0,  2.0,  3.0,  4.0],
             [10.0, 11.0, 12.0, 13.0, 14.0],
             [20.0, 21.0, 22.0, 23.0, 24.0],
             [30.0, 31.0, 32.0, 33.0, 34.0],
             [40.0, 41.0, 42.0, 43.0, 44.0]]
        )
        let halfSpectrum = rfft2(grid)
        let recovered = irfft2(halfSpectrum)
        @Assert(approxEqual(grid, recovered))
    }

    // 2-D real round trip with an even number of columns (6).
    @TestCase
    func testFFT6(): Unit {
        let grid = matrix<Float64>(
            [[ 0.0,  1.0,  2.0,  3.0,  4.0, 5.0],
             [10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
             [20.0, 21.0, 22.0, 23.0, 24.0, 25.0],
             [30.0, 31.0, 32.0, 33.0, 34.0, 35.0],
             [40.0, 41.0, 42.0, 43.0, 44.0, 45.0]]
        )
        let halfSpectrum = rfft2(grid)
        let recovered = irfft2(halfSpectrum)
        @Assert(approxEqual(grid, recovered))
    }
}