package scientific.ode

import std.unittest.*
import std.unittest.testmacro.*

import std.math.*
import std.collection.*
import std.time.*

import scientific.matplot.*
import scientific.linear.*

// Step-size control constants used by RungeKutta.step:
// MIN_FACTOR is the smallest factor a rejected step's size may be multiplied by,
// MAX_FACTOR the largest factor an accepted step's size may grow by.
let MIN_FACTOR: Float64 = 0.2
let MAX_FACTOR: Float64 = 10.0
// Safety multiplier applied to the error-based factor err^error_exponent.
let SAFETY: Float64 = 0.9
// NOTE(review): beta/alpha are not referenced anywhere in the code shown here.
// alpha = 0.2 - 0.75 * beta resembles the step-size stabilisation exponent used in
// Hairer's DOPRI5 — confirm the intent before relying on these.
let beta: Float64 = 0.0
let alpha: Float64 = 0.2 - beta * 0.75

// Some test functions
/**
 * Right-hand side of the linear system y0' = y1, y1' = 2*y1 - y0.
 * Returns component `j` of the derivative: `j == 1` gives y1', any other `j` gives y0'.
 * `x` is unused.
 */
public func Fun1(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    return match (j) {
        case 1 => 2.0 * y[1] - y[0]
        case _ => y[1]
    }
}

/**
 * Right-hand side of the third-order linear system
 *   y0' = y1, y1' = y2, y2' = 2*y2 - y1 + 2*y0.
 * `j == 2` gives y2', `j == 1` gives y1', any other `j` gives y0'. `x` is unused.
 */
public func Fun2(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    let derivative = match (j) {
        case 2 => 2.0 * y[2] - y[1] + 2.0 * y[0]
        case 1 => y[2]
        case _ => y[1]
    }
    return derivative
}

/**
 * Scalar right-hand side y' = y^2 * cos(x + y). The component index `j` is ignored.
 */
public func Fun3(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    let u = y[0]
    return u * u * cos(x + u)
}


/**
 * Right-hand side of y0' = y1, y1' = x^2 - y1.
 * `j == 1` gives y1'; any other `j` gives y0'.
 */
public func Fun5(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    if (j != 1) {
        return y[1]
    }
    return x * x - y[1]
}

/**
 * Scalar right-hand side y' = y^2 * (cos(x) - sin(x)) - y. The index `j` is ignored.
 */
public func Fun6(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    let u = y[0]
    let trig = cos(x) - sin(x)
    return u * u * trig - u
}

/**
 * Benchmark system (predator–prey form):
 *   y0' = 1.5*y0 - y0*y1,  y1' = -3*y1 + y0*y1.
 * `j == 1` gives y1'; any other `j` gives y0'. `x` is unused.
 */
public func BM_Func1(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    let product = 1.0 * y[0] * y[1]
    if (j != 1) {
        return 1.5 * y[0] - product
    }
    return -3.0 * y[1] + product
}

/**
 * Benchmark system:
 *   y0' = -2*y1*y2,  y1' = 1.25*y0*y2,  y2' = -0.5*y0*y1 + 0.25*sin(x)^2.
 * `j == 2` gives y2', `j == 1` gives y1', any other `j` gives y0'.
 */
public func BM_Func2(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    return match (j) {
        case 2 => -0.5 * y[0] * y[1] + 0.25 * sin(x) * sin(x)
        case 1 => 1.25 * y[0] * y[2]
        case _ => -2.0 * y[1] * y[2]
    }
}

/**
 * Benchmark system: y0' = y1, y1' = y2, y2' = 1e4 * y2^2.
 * `j == 2` gives y2', `j == 1` gives y1', any other `j` gives y0'. `x` is unused.
 */
public func BM_Func3(x: Float64, y: Array<Float64>, j: Int64): Float64 {
    if (j == 1) {
        return y[2]
    } else if (j == 2) {
        return 1e4 * y[2] * y[2]
    }
    return y[1]
}

/**
 * State shared by the one-step ODE solvers in this file.
 *
 * Holds the right-hand side `fun` (called as `fun(t, y, j)` to get component j of y'),
 * the current time `t` and state `y`, the end point `t_bound`, the system size `n`,
 * the integration direction (+1.0 when t_bound > t0, otherwise -1.0) and an initial
 * step size of 0.01.
 */
open class OdeSolver {
    var fun: (Float64, Array<Float64>, Int64) -> Float64
    var t_bound: Float64
    var n: Int64
    var direction: Float64
    var t: Float64
    var y: Array<Float64>
    var step_size: Float64

    init(fun: (Float64, Array<Float64>, Int64) -> Float64, t0: Float64, y0: Array<Float64>, t_bound: Float64) {
        // y0 is stored as given (no copy is taken here).
        this.y = y0
        this.n = y0.size
        this.fun = fun
        this.t = t0
        this.t_bound = t_bound
        this.direction = if (t_bound > t0) { 1.0 } else { -1.0 }
        this.step_size = 0.01
    }
}

/**
 * Adaptive explicit embedded Runge–Kutta integrator.
 *
 * Subclasses supply the Butcher tableau (`A`, `B`, `C`), the embedded error weights
 * `E` (n_stages + 1 entries; the last one multiplies f evaluated at the candidate new
 * point) and a dense-output matrix `P` (not used by `step`).
 */
open class RungeKutta <: OdeSolver {
    var A: Matrix<Float64> = empty(0, 0)
    var B: Vector<Float64> = empty(0)
    var C: Vector<Float64> = empty(0)
    var E: Vector<Float64> = empty(0)
    var P: Matrix<Float64> = empty(0, 0)
    var n_stages: Int64
    // Candidate state of the step being attempted; copied into `y` once accepted.
    var y_new: Array<Float64>
    var max_step: Float64 = 10.0
    var min_step: Float64 = 1e-15
    var atol: Float64
    var rtol: Float64
    // Magnitude of the next step to try; the sign comes from `direction`.
    var h: Float64
    var error_exponent: Float64
    var error_estimator_order: Int64
    // Stage derivatives: row j (1 <= j < n_stages) holds k_j. Stage 0 lives in the
    // local `dxdy` inside `step`, so row 0 is unused.
    var K: Matrix<Float64> = empty(0, 0)

    init(fun: (Float64, Array<Float64>, Int64) -> Float64, t0: Float64, y0: Array<Float64>,
         t_bound: Float64, atol: Float64, rtol: Float64, error_estimator_order: Int64, n_stages: Int64) {
        super(fun, t0, y0, t_bound)
        this.rtol = rtol
        this.atol = atol
        // Independent storage: `step` writes y_new element-wise, which must not alter
        // `y` (or the caller's y0) while an attempt may still be rejected.
        this.y_new = y0.clone()
        this.n_stages = n_stages
        this.h = this.step_size
        this.K = empty(n_stages, n)
        this.error_estimator_order = error_estimator_order
        this.error_exponent = -1.0 / Float64(error_estimator_order + 1)
    }

    /**
     * Advances (`t`, `y`) by one accepted step toward `t_bound`.
     *
     * The step magnitude is clamped to [min_step, max_step] and to the remaining
     * distance, then shrunk after each rejected attempt until the scaled RMS error
     * is below 1. On acceptance `h` is set to the proposed size of the next step.
     *
     * @throws IllegalArgumentException if the step would have to drop below `min_step`.
     */
    public func step() {
        let remaining = abs(t_bound - t)
        h = min(h, max_step)
        h = max(h, min_step)
        h = min(h, remaining)
        var step_accepted: Bool = false
        var step_rejected: Bool = false
        var h_try = h
        // Stage 0: f(t, y).
        var dxdy = Array<Float64>(n, { i => Float64(0) })
        for (i in 0..n) {
            dxdy[i] = fun(t, y, i)
        }
        var ytemp = Array<Float64>(n, { i => Float64(0) })
        var y_err = Array<Float64>(n, { i => Float64(0) })
        while (!step_accepted) {
            if (h_try < min_step) {
                throw IllegalArgumentException("Required step size is less than spacing between numbers.")
            }
            // Signed step so integration proceeds toward t_bound on either side of t.
            let hs = direction * h_try
            // k_j = f(t + c_j*hs, y + hs * sum_{k<j} a_jk * k_k), with k_0 = dxdy.
            for (j in 1..n_stages) {
                for (i in 0..n) {
                    ytemp[i] = y[i] + hs * A[j][0] * dxdy[i]
                    for (k in 1..j) {
                        ytemp[i] += hs * A[j][k] * K[k, i]
                    }
                }
                for (i in 0..n) {
                    K[j, i] = fun(t + C[j] * hs, ytemp, i)
                }
            }
            for (i in 0..n) {
                y_new[i] = y[i] + hs * B[0] * dxdy[i]
                for (j in 1..n_stages) {
                    y_new[i] += hs * B[j] * K[j, i]
                }
            }
            // Embedded error estimate; E[n_stages] weights f at the candidate point.
            for (i in 0..n) {
                y_err[i] = hs * (E[0] * dxdy[i] + E[n_stages] * fun(t + hs, y_new, i))
                for (j in 1..n_stages) {
                    y_err[i] += hs * E[j] * K[j, i]
                }
            }
            // RMS of the error relative to atol + rtol * max(|y|, |y_new|).
            var err: Float64 = 0.0
            for (i in 0..n) {
                let scale = atol + max(abs(y[i]), abs(y_new[i])) * rtol
                let ratio = y_err[i] / scale
                err += ratio * ratio
            }
            err = sqrt(err / Float64(n))
            if (err < 1.0) {
                var factor: Float64 = MAX_FACTOR
                if (err != 0.0) {
                    factor = min(MAX_FACTOR, SAFETY * pow(err, error_exponent))
                }
                if (step_rejected) {
                    factor = min(1.0, factor)
                }
                // Land exactly on t_bound for the final step: t + (t_bound - t) can round
                // to a value different from t_bound.
                if (h_try >= remaining) {
                    t = t_bound
                } else {
                    t = t + hs
                }
                h_try = h_try * factor
                step_accepted = true
            } else {
                h_try = h_try * max(MIN_FACTOR, SAFETY * pow(err, error_exponent))
                step_rejected = true
            }
        }
        h = h_try
        // Copy so the next step's in-place writes to y_new cannot modify y.
        y = y_new.clone()
    }
}

/**
 * Bogacki–Shampine 3(2) embedded pair: 3 stages, error estimator of order 2.
 * The coefficient tables appear to match SciPy's `RK23`. `P` holds dense-output
 * coefficients, which `RungeKutta.step` does not use.
 */
class RK23 <: RungeKutta {
    init(fun: (Float64, Array<Float64>, Int64) -> Float64, t0: Float64, y0: Array<Float64>,
         t_bound: Float64, atol: Float64, rtol: Float64) {
        // error_estimator_order = 2, n_stages = 3
        super(fun, t0, y0, t_bound, atol, rtol, 2, 3)
        let nodes = [0.0, 1.0/2.0, 3.0/4.0]
        let tableau = [[0.0, 0.0, 0.0],
                       [1.0/2.0, 0.0, 0.0],
                       [0.0, 3.0/4.0, 0.0]]
        let weights = [2.0/9.0, 1.0/3.0, 4.0/9.0]
        let errorWeights = [5.0/72.0, -1.0/12.0, -1.0/9.0, 1.0/8.0]
        let interpolation = [[1.0, -4.0 / 3.0, 5.0 / 9.0],
                             [0.0, 1.0, -2.0/3.0],
                             [0.0, 4.0/3.0, -8.0/9.0],
                             [0.0, -1.0, 1.0]]
        this.C = vector(nodes)
        this.A = matrix(tableau)
        this.B = vector(weights)
        this.E = vector(errorWeights)
        this.P = matrix(interpolation)
    }
}

/**
 * Dormand–Prince 5(4) embedded pair: 6 stages, error estimator of order 4.
 * The coefficient tables appear to match SciPy's `RK45`. `P` holds dense-output
 * coefficients, which `RungeKutta.step` does not use.
 */
class RK45 <: RungeKutta {
    init(fun: (Float64, Array<Float64>, Int64) -> Float64, t0: Float64, y0: Array<Float64>,
         t_bound: Float64, atol: Float64, rtol: Float64) {
        // error_estimator_order = 4, n_stages = 6
        super(fun, t0, y0, t_bound, atol, rtol, 4, 6)
        let nodes = [0.0, 1.0/5.0, 3.0/10.0, 4.0/5.0, 8.0/9.0, 1.0]
        let tableau = [
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [1.0/5.0, 0.0, 0.0, 0.0, 0.0],
            [3.0/40.0, 9.0/40.0, 0.0, 0.0, 0.0],
            [44.0/45.0, -56.0/15.0, 32.0/9.0, 0.0, 0.0],
            [19372.0/6561.0, -25360.0/2187.0, 64448.0/6561.0, -212.0/729.0, 0.0],
            [9017.0/3168.0, -355.0/33.0, 46732.0/5247.0, 49.0/176.0, -5103.0/18656.0]
        ]
        let weights = [35.0/384.0, 0.0, 500.0/1113.0, 125.0/192.0, -2187.0/6784.0, 11.0/84.0]
        let errorWeights = [-71.0/57600.0, 0.0, 71.0/16695.0, -71.0/1920.0, 17253.0/339200.0, -22.0/525.0, 1.0/40.0]
        let interpolation = [
            [1.0, -8048581381.0/2820520608.0, 8663915743.0/2820520608.0, -12715105075.0/11282082432.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 131558114200.0/32700410799.0, -68118460800.0/10900136933.0, 87487479700.0/32700410799.0],
            [0.0, -1754552775.0/470086768.0, 14199869525.0/1410260304.0, -10690763975.0/1880347072.0],
            [0.0, 127303824393.0/49829197408.0, -318862633887.0/49829197408.0, 701980252875.0 / 199316789632.0],
            [0.0, -282668133.0/205662961.0, 2019193451.0/616988883.0, -1453857185.0/822651844.0],
            [0.0, 40617522.0/29380423.0, -110615467.0/29380423.0, 69997945.0/29380423.0]
        ]
        this.C = vector(nodes)
        this.A = matrix(tableau)
        this.B = vector(weights)
        this.E = vector(errorWeights)
        this.P = matrix(interpolation)
    }
}

/**
 * Integrates y' = f(t, y) from `t0` to `t_bound` with an adaptive explicit
 * Runge–Kutta method and records every accepted step.
 *
 * @param fun component-wise right-hand side: `fun(t, y, j)` returns component j of y'.
 * @param t0 initial time.
 * @param t_bound final time.
 * @param y0 initial state; its size fixes the system dimension.
 * @param method "RK45" or "RK23".
 * @param rtol relative tolerance.
 * @param atol absolute tolerance.
 * @return (times, states): `times` starts with t0 and lists each accepted time;
 *         `states[i]` holds component i of y at each of those times.
 * @throws IllegalArgumentException for an unknown `method`, or (from the solver's
 *         `step`) when the required step size falls below its minimum.
 */
public func solveIvp(fun: (Float64, Array<Float64>, Int64) -> Float64, t0: Float64,
                     t_bound: Float64, y0: Array<Float64>, method: String,
                     rtol: Float64, atol: Float64): (Vector<Float64>, Array<Vector<Float64>>) {
    var solver: RungeKutta
    if (method == "RK45") {
        solver = RK45(fun, t0, y0, t_bound, atol, rtol)
    } else if (method == "RK23") {
        solver = RK23(fun, t0, y0, t_bound, atol, rtol)
    } else {
        throw IllegalArgumentException("solveIvp: no such method")
    }

    let dim = y0.size
    let times = ArrayList<Float64>()
    let states = Array<ArrayList<Float64>>(dim, { _ => ArrayList<Float64>() })

    times.add(t0)
    for (i in 0..dim) {
        states[i].add(y0[i])
    }
    // Continue while t_bound still lies ahead in the integration direction. An exact
    // `!=` test could loop forever (or force sub-min_step steps) if rounding carried
    // `t` slightly past `t_bound`.
    while (solver.direction * (t_bound - solver.t) > 0.0) {
        solver.step()
        times.add(solver.t)
        for (i in 0..dim) {
            states[i].add(solver.y[i])
        }
    }

    let vec_x = vector(times.toArray())
    let vec_y = Array<Vector<Float64>>(dim, { i => vector<Float64>(states[i].toArray()) })
    return (vec_x, vec_y)
}

/**
 * Manual checks and benchmarks for `solveIvp`.
 *
 * Every `@TestCase` annotation is commented out, so these methods do not run by
 * default. The plotting tests save SVG files under ./tests/imgs/ode/. For each
 * tolerance pair, the benchmark tests print the final value of component 0, its
 * absolute error against the stored reference value (where one is given), and the
 * elapsed wall-clock time.
 */
@Test
public class TestRungeKutta {
    // @TestCase
    func testOde1(): Unit {
        let (vec_x, vec_y) = solveIvp(Fun3, 0.0, 100.0, [0.2], "RK45", 1e-3, 1e-6)
        plot(vec_x, vec_y[0])
        xlabel("t")
        ylabel("y(t)")
        title("Solution to ODE y'=y^2cos(t+y)")
        save("./tests/imgs/ode/rk_45_variable_step.svg", "svg")
        clear()
    }

    // @TestCase
    func testOde2(): Unit {
        let (vec_x, vec_y) = solveIvp(Fun3, 0.0, 100.0, [0.2], "RK23", 1e-3, 1e-6)
        plot(vec_x, vec_y[0])
        xlabel("t")
        ylabel("y(t)")
        title("Solution to ODE y'=y^2cos(t+y)")
        save("./tests/imgs/ode/rk_23_variable_step.svg", "svg")
        clear()
    }

    // @TestCase
    func testOde3(): Unit {
        let (vec_x, vec_y) = solveIvp(Fun1, 0.0, 2.0, [1.0, 0.0], "RK45", 1e-3, 1e-6)
        plot(vec_x, vec_y[0])
        hold(true)
        plot(vec_x, vec_y[1])
        xlabel("t")
        ylabel("y(t)")
        title("Solution to ODE y1'=y2, y2'=2*y2-y1")
        save("./tests/imgs/ode/rk_45_Fun1_variable_step.svg", "svg")
        clear()
    }

    // @TestCase
    func testBM1(): Unit {
        let abstols = [1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13]
        let reltols = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10]
        let test_sol: Float64 = 1.026344767572481
        print("testBM1\n")
        for (i in 0..abstols.size) {
            let start = DateTime.now()
            // solveIvp takes (rtol, atol) in that order.
            let (vec_x, vec_y) = solveIvp(BM_Func1, 0.0, 10.0, [1.0, 1.0], "RK23", reltols[i], abstols[i])
            let end = DateTime.now()
            let y0 = vec_y[0][vec_x.size() - 1]
            let error = abs(test_sol - y0)
            print("${y0}  ${error}  ${end - start}\n")
        }
    }

    // @TestCase
    func testBM2(): Unit {
        let abstols = [1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13]
        let reltols = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10]
        let test_sol: Float64 = 0.900595217062476
        print("testBM2\n")
        for (i in 0..abstols.size) {
            let start = DateTime.now()
            // solveIvp takes (rtol, atol) in that order.
            let (vec_x, vec_y) = solveIvp(BM_Func2, 0.0, 100.0, [1.0, 0.0, 0.9], "RK23", reltols[i], abstols[i])
            let end = DateTime.now()
            let y0 = vec_y[0][vec_x.size() - 1]
            let error = abs(test_sol - y0)
            print("${y0}  ${error}  ${end - start}\n")
        }
    }

    // @TestCase
    func testBM3(): Unit {
        let abstols = [1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12]
        let reltols = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
        print("testBM3\n")
        for (i in 0..abstols.size) {
            let start = DateTime.now()
            // solveIvp takes (rtol, atol) in that order.
            let (vec_x, vec_y) = solveIvp(BM_Func3, 0.0, 1e5, [1.0, 0.0, 0.0], "RK45", reltols[i], abstols[i])
            let end = DateTime.now()
            let y0 = vec_y[0][vec_x.size() - 1]
            print("${y0}  ${end - start}\n")
        }
    }
}