package scientific.matplot

import std.collection.*
import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.random.*
import scientific.stats.normal.*

/*
 * Native bindings for parallel-coordinate plots.
 *
 * Element type for `x` and `color`: Float64.
 * `x` is passed together with its row/column counts (`x_row`, `x_col`);
 * `color` is passed together with its element count (`col_size`).
 * Both functions return an opaque handle that the Cangjie side wraps in
 * `ParallelLines`.
 * NOTE(review): the storage order expected for `x` (row- vs column-major) is
 * decided by the native side and the Matrix layout — confirm before changing
 * how callers pass the data.
 */
foreign func c_parallelplot(x: CPointer<Unit>, x_row: Int64, x_col: Int64): CPointer<Unit>
foreign func c_parallelplot_color(x: CPointer<Unit>, x_row: Int64, x_col: Int64,
                                  color: CPointer<Unit>, col_size: Int64): CPointer<Unit>

/*
 * Takes a ParallelLines handle and returns a native array of AxisType handles.
 * The array length is not returned; the Cangjie caller reads `num_axis`
 * entries from it.
 * NOTE(review): ownership of the returned array is not visible here — confirm
 * whether the caller is expected to free it.
 */
foreign func c_parallel_lines_axis(plines: CPointer<Unit>): CPointer<CPointer<Unit>>

/*
 * Draws a parallel-coordinate plot of `x` and returns its wrapper.
 * Each row of `x` becomes one axis, so the wrapper's axis count is the
 * matrix row count.
 */
public func parallelplot(x: Matrix<Float64>): ParallelLines {
    let axisCount = x.getRows()
    let lines = ParallelLines(unsafe { c_parallelplot(x.ptr, axisCount, x.getCols()) })
    lines.num_axis = axisCount
    lines
}

/*
 * Draws a parallel-coordinate plot of `x`, colouring the lines by `color`,
 * and returns its wrapper.
 * Each row of `x` becomes one axis; `color` is handed to the native side
 * together with its length.
 */
public func parallelplot(x: Matrix<Float64>, color: Vector<Float64>): ParallelLines {
    let axisCount = x.getRows()
    let raw = unsafe {
        c_parallelplot_color(x.ptr, axisCount, x.getCols(), color.ptr, color.size())
    }
    let lines = ParallelLines(raw)
    lines.num_axis = axisCount
    lines
}

/*
 * Cangjie-side wrapper around a native parallel-lines plot handle.
 */
public class ParallelLines {
    // Opaque native handle; defaults to an empty CPointer until `init` sets it.
    var ptr: CPointer<Unit> = CPointer<Unit>()
    // Number of axes; parallelplot() sets this to the input matrix's row count.
    var num_axis: Int64 = 0

    init(ptr: CPointer<Unit>) {
        this.ptr = ptr
    }

    /*
     * Returns one AxisType wrapper per axis of this plot (`num_axis` entries).
     * Returns an empty array when the plot handle, or the axis array returned
     * by the native side, is null, instead of dereferencing a null pointer.
     * NOTE(review): every call asks the native side for a fresh handle array;
     * its ownership is not visible here, so callers should avoid calling this
     * repeatedly.
     */
    public func axis(): Array<AxisType> {
        if (ptr.isNull()) {
            return Array<AxisType>()
        }
        let handles = unsafe { c_parallel_lines_axis(ptr) }
        if (handles.isNull()) {
            return Array<AxisType>()
        }
        let res = ArrayList<AxisType>()
        for (i in 0..num_axis) {
            res.add(AxisType(unsafe { handles.read(i) }))
        }
        return res.toArray()
    }
}

/*
 * Renders a three-feature parallel plot to parallel_plot1.svg.
 * Columns 0-49 and 50-99 are generated with different parameters. Feature 2
 * (row 1) holds only 1.0/0.0, so its axis gets "false"/"true" tick labels.
 */
public func testParallelPlot1() {
    let x = empty<Float64>(3, 100)
    var r: Random = Random(0)
    x[0] = concat(rand(r, 50, 78.0, 100.0), rand(r, 50, 65.0, 91.0))
    x[1] = concat(vector<Float64>(50, 1.0), vector<Float64>(50, 0.0))
    x[2] = concat(rand(r, 50, 122.0, 140.0), rand(r, 50, 105.0, 131.0))

    let p: ParallelLines = parallelplot(x)

    xticks(vector<Float64>([1.0, 2.0, 3.0]))
    xticklabels(["f_1", "f_2", "f_3"])

    // Fetch the axis wrappers once: each p.axis() call crosses the FFI and
    // receives a new native handle array.
    let axes = p.axis()
    axes[1].tick_values(vector<Float64>([0.0, 1.0]))
    axes[1].ticklabels(["false", "true"])
    save("./tests/imgs/parallel_plot/parallel_plot1.svg", "svg")
    clear()
}

/*
 * Renders a four-feature, colour-mapped parallel plot to parallel_plot2.svg.
 * Row 3 holds the levels 3.0/1.0/2.0 and doubles as the colour vector. Rows 1
 * and 3 get categorical tick labels on their axes.
 */
public func testParallelPlot2() {
    let x = empty<Float64>(4, 100)
    var r: Random = Random(0)
    x[0] = concat(rand(r, 50, 78.0, 200.0), rand(r, 50, 65.0, 91.0))
    x[1] = concat(vector<Float64>(50, 1.0), vector<Float64>(50, 0.0))
    x[2] = concat(rand(r, 50, 122.0, 140.0), rand(r, 50, 105.0, 131.0))
    x[3] = concat([vector<Float64>(25, 3.0), vector<Float64>(50, 1.0),
                   vector<Float64>(25, 2.0)])

    let p: ParallelLines = parallelplot(x, x[3])

    xticks(vector<Float64>([1.0, 2.0, 3.0, 4.0]))
    xticklabels(["f_1", "f_2", "f_3", "f_4"])

    // Fetch the axis wrappers once instead of four separate FFI round trips.
    let axes = p.axis()
    axes[1].tick_values(vector<Float64>([0.0, 1.0]))
    axes[1].ticklabels(["false", "true"])

    axes[3].tick_values(vector<Float64>([1.0, 2.0, 3.0]))
    axes[3].ticklabels(["low", "medium", "high"])
    save("./tests/imgs/parallel_plot/parallel_plot2.svg", "svg")
    clear()
}

/*
 * Renders a four-feature parallel plot, coloured by a ±1 indicator, to
 * parallel_plot3.svg.
 * Row 1 is row 0 plus a random offset. Row 2 is 1.0 where row 0 exceeds 50.0
 * and -1.0 otherwise. Row 3 is the cosine of a second random vector.
 */
public func testParallelPlot3() {
    let data = empty<Float64>(4, 100)
    var rng: Random = Random(0)
    data[0] = randn(rng, 100, 50.0, 200.0)
    let offsets = rand(rng, 100, -30.0, 30.0)
    let angles = rand(rng, 100, -30.0, 30.0)
    data[1] = data[0] + offsets
    for (col in 0..100) {
        data[2, col] = if (data[0, col] > 50.0) { 1.0 } else { -1.0 }
        data[3, col] = cos(angles[col])
    }
    let indicator = data[2]
    parallelplot(data, indicator)
    save("./tests/imgs/parallel_plot/parallel_plot3.svg", "svg")
    clear()
}

/*
 * Entry point for the parallel-plot rendering tests: prints a banner, then
 * runs the three individual tests in order.
 */
public func testParallelPlot() {
    let banner = "  + testParallelPlot\n"
    print(banner)
    testParallelPlot1()
    testParallelPlot2()
    testParallelPlot3()
}