package scientific.fft

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*

// Digital filter frequency response (modelled on scipy.signal.freqz).
//
// Evaluates H(e^{jω}) = B(e^{jω}) / A(e^{jω}), where
//   B(e^{jω}) = Σ_k b[k]·e^{-jkω}   and   A(e^{jω}) = Σ_k a[k]·e^{-jkω},
// on `worN` equally spaced frequencies starting at 0.
//
// Parameters:
//   b     - numerator coefficients; b[k] multiplies z^{-k}.
//   a     - denominator coefficients; a[k] multiplies z^{-k}. Must not be empty.
//   worN  - number of frequency points to evaluate.
//   whole - if true the grid covers [0, fs); otherwise [0, fs/2).
//   fs    - sampling frequency. Returned frequencies are expressed in the same
//           units; the default 2π gives frequencies in rad/sample.
// Returns:
//   (W, H): the frequency grid (units of fs) and the complex response at each point.
// Throws:
//   IllegalArgumentException if `a` is empty.
// Note: if A(e^{jω}) is exactly zero at a grid point, that H entry will not be finite.
public func freqz(b: Vector<Float64>, a: Vector<Float64>, worN!:Int64 = 512,
                  whole!:Bool = false, fs!:Float64 = 2.0 * Float64.getPI()):
        (Vector<Float64>, Vector<Complex64>) {
    let M = b.size()
    let N = a.size()
    if (N == 0) {
        throw IllegalArgumentException("freqz: denominator coefficients `a` must not be empty")
    }
    var H: Vector<Complex64> = empty(worN)
    var W: Vector<Float64> = empty(worN)

    // Upper (exclusive) edge of the frequency grid, in units of fs.
    let f_cal: Float64 = if (whole) { fs } else { fs / 2.0 }
    let w: Float64 = f_cal / Float64(worN)
    // Converts a grid frequency (units of fs) into normalized angular frequency
    // in rad/sample. For the default fs = 2π this factor is exactly 1.0.
    let to_rad: Float64 = 2.0 * Float64.getPI() / fs

    for (i in 0..worN) {
        let w_i = w * Float64(i)
        let omega = w_i * to_rad

        // Denominator A(e^{jω}).
        var a_temp = Complex64(0.0, 0.0)
        for (k in 0..N) {
            let phase = Float64(k) * omega
            a_temp += Complex64(a[k] * cos(phase), -a[k] * sin(phase))
        }

        // H = B / A is computed as B · conj(A) / |A|^2, because only
        // multiplication, conjugation and normsq are used on Complex64 here.
        let aa = 1.0 / a_temp.normsq()
        var b_temp = Complex64(0.0, 0.0)
        for (k in 0..M) {
            let phase = Float64(k) * omega
            b_temp += Complex64(aa * b[k] * cos(phase), -aa * b[k] * sin(phase))
        }

        W[i] = w_i
        H[i] = b_temp * a_temp.conjugate()
    }
    return (W, H)
}

// Analog filter frequency response (modelled on scipy.signal.freqs).
//
// Evaluates H(jω) = B(jω) / A(jω), where
//   B(s) = b[0]·s^{M-1} + b[1]·s^{M-2} + ... + b[M-1]
//   A(s) = a[0]·s^{N-1} + a[1]·s^{N-2} + ... + a[N-1]
// i.e. coefficients are ordered from the highest power down.
//
// Parameters:
//   b    - numerator coefficients (highest power first). An empty `b` is the
//          zero polynomial, so H is zero.
//   a    - denominator coefficients (highest power first). Must not be empty.
//   worN - number of frequency points. The grid comes from
//          logspace(-1.0, 1.0, num: worN); presumably worN log-spaced points
//          from 10^-1 to 10^1 (TODO confirm against logspace's contract).
// Returns:
//   (W, H): the angular-frequency grid and the complex response at each point.
// Throws:
//   IllegalArgumentException if `a` is empty.
// Note: if A(jω) is exactly zero at a grid point, that H entry will not be finite.
public func freqs(b: Vector<Float64>, a: Vector<Float64>, worN!:Int64 = 200): (Vector<Float64>, Vector<Complex64>){
    let N = a.size()
    if (N == 0) {
        throw IllegalArgumentException("freqs: denominator coefficients `a` must not be empty")
    }
    let M = b.size()
    let W: Vector<Float64> = logspace<Float64>(-1.0, 1.0, num: worN)
    var Wr: Vector<Float64> = empty(worN)
    var H: Vector<Complex64> = empty(worN)

    for (i in 0..W.size()) {
        let w_i = W[i]
        let jw = Complex64(0.0, w_i)

        // Horner's rule: A(jω) = (...((a[0]·jω + a[1])·jω + a[2])...)·jω + a[N-1]
        var a_temp = Complex64(0.0, 0.0)
        for (k in 0..N) {
            a_temp = a_temp * jw + Complex64(a[k], 0.0)
        }

        // H = B / A is computed as B · conj(A) / |A|^2; the 1/|A|^2 factor is
        // folded into the numerator coefficients.
        let aa = 1.0 / a_temp.normsq()
        var b_temp = Complex64(0.0, 0.0)
        for (k in 0..M) {
            b_temp = b_temp * jw + Complex64(aa * b[k], 0.0)
        }

        Wr[i] = w_i
        H[i] = b_temp * a_temp.conjugate()
    }
    return (Wr, H)
}

@Test
public class testFreq {
    // freqz with default arguments (512 points over [0, π)): spot-check the
    // frequency grid and complex response at indices 490..494 inclusive.
    @TestCase
    func testFreq1(): Unit {
        let num: Vector<Float64> = vector<Float64>([10.0, -5.0, 3.5])
        let den: Vector<Float64> = vector<Float64>([1.2, 2.1])
        let (grid, response) = freqz(num, den)

        let wantW = vector<Float64>([3.00660234, 3.01273827, 3.01887419, 3.02501011, 3.03114604])
        let wantH = vector<Complex64>([
            Complex64(-19.39452066, 4.39878641),
            Complex64(-19.49254083, 4.22105667),
            Complex64(-19.58689081, 4.040422),
            Complex64(-19.67745504, 3.8569767),
            Complex64(-19.76412097, 3.67082077)
        ])

        let gotW: Vector<Float64> = grid[490..495 : 1]
        let gotH: Vector<Complex64> = response[490..495 : 1]
        @Assert(approxEqual(wantW, gotW))
        @Assert(approxEqualC(wantH, gotH, atol:1e-6))
    }

    // freqs with default arguments (200 log-spaced points): spot-check the
    // frequency grid and complex response at indices 20..24 inclusive.
    @TestCase
    func testFreq2(): Unit {
        let num: Vector<Float64> = vector<Float64>([2.0, -3.0, 1.5])
        let den: Vector<Float64> = vector<Float64>([1.0, 2.5, 3.1, 4.7])
        let (grid, response) = freqs(num, den)

        let wantW: Vector<Float64> = vector<Float64>([0.15885651, 0.16257557, 0.16638169, 0.17027692, 0.17426334])
        let wantH: Vector<Complex64> = vector<Complex64>([
            Complex64(0.2984683, -0.13421764),
            Complex64(0.29748368, -0.13732975),
            Complex64(0.29645192, -0.14051256),
            Complex64(0.29537075, -0.14376756),
            Complex64(0.29423777, -0.14709626)
        ])

        let gotW: Vector<Float64> = grid[20..25 : 1]
        let gotH: Vector<Complex64> = response[20..25 : 1]
        @Assert(approxEqual(wantW, gotW))
        @Assert(approxEqualC(wantH, gotH, atol:1e-6))
    }
}