package scientific.matplot

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.random.*
import scientific.stats.normal.*

/* Binding to a C allocator (presumably the standard library malloc — confirm):
   takes a size and returns an untyped pointer.
   NOTE(review): not referenced in the part of the file visible here. */
foreign func malloc(size: UIntNative): CPointer<Unit>

/* Create a line plot from two coordinate buffers.
   Return pointer to matplot::line.
   Type for x and y: Float64.
   x_len and y_len are the lengths passed alongside x and y (the callers in
   this file pass the vector size for both); line_spec is the line
   specification as a C string. */
foreign func c_plot(x: CPointer<Unit>, x_len: Int64, y: CPointer<Unit>, y_len: Int64, line_spec: CString): CPointer<Unit>


/* Set the marker indices of an existing line.
   Return pointer to matplot::line.
   Type for indices: Int64; indices_len is the length passed with the buffer. */
foreign func c_marker_indices(line: CPointer<Unit>, indices: CPointer<Unit>, indices_len: Int64): CPointer<Unit>
/* Property setters. Each takes a pointer to matplot::line as its first
   argument and returns a pointer to matplot::line. String-valued properties
   are passed as CString, numeric ones as Float32, flags as Bool. */
foreign func c_color_str(line: CPointer<Unit>, c: CString): CPointer<Unit>
foreign func c_color_rgb(line: CPointer<Unit>, r: Float32, g: Float32, b: Float32): CPointer<Unit>
foreign func c_color_rgba(line: CPointer<Unit>, r: Float32, g: Float32, b: Float32, a: Float32): CPointer<Unit>
foreign func c_line_width(line: CPointer<Unit>, line_width: Float32): CPointer<Unit>
foreign func c_line_style(line: CPointer<Unit>, line_style: CString): CPointer<Unit>
foreign func c_marker_str(line: CPointer<Unit>, marker: CString): CPointer<Unit>
foreign func c_marker_size(line: CPointer<Unit>, marker_size: Float32): CPointer<Unit>
foreign func c_marker_color_str(line: CPointer<Unit>, c: CString): CPointer<Unit>
foreign func c_marker_color_rgb(line: CPointer<Unit>, r: Float32, g: Float32, b: Float32): CPointer<Unit>
foreign func c_marker_face(line: CPointer<Unit>, v: Bool): CPointer<Unit>
foreign func c_marker_face_color_str(line: CPointer<Unit>, c: CString): CPointer<Unit>
foreign func c_marker_face_color_rgb(line: CPointer<Unit>, r: Float32, g: Float32, b: Float32): CPointer<Unit>

/* Marker shapes selectable for a plotted line.
   toString() spells each constructor as a lower-case snake_case name;
   Line.marker(style) forwards that name to the string-based marker setter. */
public enum MarkerStyle {
    | None
    | PlusSign
    | Circle
    | Asterisk
    | Point
    | Cross
    | Square
    | Diamond
    | UpwardPointingTriangle
    | DownwardPointingTriangle
    | RightPointingTriangle
    | LeftPointingTriangle
    | Pentagram
    | Hexagram

    /* Name of this marker style as a snake_case string. */
    func toString(): String {
        let name = match (this) {
            case None => "none"
            case PlusSign => "plus_sign"
            case Circle => "circle"
            case Asterisk => "asterisk"
            case Point => "point"
            case Cross => "cross"
            case Square => "square"
            case Diamond => "diamond"
            case UpwardPointingTriangle => "upward_pointing_triangle"
            case DownwardPointingTriangle => "downward_pointing_triangle"
            case RightPointingTriangle => "right_pointing_triangle"
            case LeftPointingTriangle => "left_pointing_triangle"
            case Pentagram => "pentagram"
            case Hexagram => "hexagram"
        }
        return name
    }
}

/* Axis layout modes; toString() yields the lower-case name of the mode. */
public enum AxisStyle {
    | Equal
    | Tight
    | Square

    /* Name of this axis style as a lower-case string. */
    func toString(): String {
        let name = match (this) {
            case Equal => "equal"
            case Tight => "tight"
            case Square => "square"
        }
        return name
    }
}


/* Handle for a native matplot::line object.
   Every setter forwards to the matching C binding, stores the pointer the
   binding returns back into `ptr`, and returns this Line so that calls can
   be chained. display_name is the exception: it returns Unit and leaves
   `ptr` untouched. String arguments are copied into a C string for the
   call and freed right afterwards. */
public open class Line {
    /* Pointer to the native line; default-constructed until init stores the real handle. */
    var ptr: CPointer<Unit> = CPointer<Unit>()

    /* Input is a pointer to matplot::line. */
    init(ptr: CPointer<Unit>) {
        this.ptr = ptr
    }

    /* Pass the marker index buffer (Int64 values) and its size to the native line. */
    public func marker_indices(indices: Vector<Int64>): Line {
        let count = indices.size()
        let updated = unsafe { c_marker_indices(this.ptr, indices.ptr, count) }
        this.ptr = updated
        return this
    }

    /* Set the line color from a string. */
    public open func color(c: String): Line {
        let text = unsafe { LibC.mallocCString(c) }
        let updated = unsafe { c_color_str(this.ptr, text) }
        this.ptr = updated
        unsafe { LibC.free(text) }
        return this
    }

    /* Set the line color from red, green and blue components. */
    public func color(r: Float32, g: Float32, b: Float32): Line {
        let updated = unsafe { c_color_rgb(this.ptr, r, g, b) }
        this.ptr = updated
        return this
    }

    /* Set the line color from red, green, blue and alpha components. */
    public func color(r: Float32, g: Float32, b: Float32, a: Float32): Line {
        let updated = unsafe { c_color_rgba(this.ptr, r, g, b, a) }
        this.ptr = updated
        return this
    }

    /* Set the line width. */
    public open func line_width(line_width: Float32): Line {
        let updated = unsafe { c_line_width(this.ptr, line_width) }
        this.ptr = updated
        return this
    }

    /* Set the line style from a string. */
    public open func line_style(line_style: String): Line {
        let text = unsafe { LibC.mallocCString(line_style) }
        let updated = unsafe { c_line_style(this.ptr, text) }
        this.ptr = updated
        unsafe { LibC.free(text) }
        return this
    }

    /* Set the marker from a string. */
    public open func marker(c: String): Line {
        let text = unsafe { LibC.mallocCString(c) }
        let updated = unsafe { c_marker_str(this.ptr, text) }
        this.ptr = updated
        unsafe { LibC.free(text) }
        return this
    }

    /* Set the marker from a MarkerStyle by delegating to the string overload. */
    public open func marker(style: MarkerStyle): Line {
        let styleName = style.toString()
        return this.marker(styleName)
    }

    /* Set the marker size. */
    public open func marker_size(marker_size: Float32): Line {
        let updated = unsafe { c_marker_size(this.ptr, marker_size) }
        this.ptr = updated
        return this
    }

    /* Set the marker color from a string. */
    public open func marker_color(c: String): Line {
        let text = unsafe { LibC.mallocCString(c) }
        let updated = unsafe { c_marker_color_str(this.ptr, text) }
        this.ptr = updated
        unsafe { LibC.free(text) }
        return this
    }

    /* Set the marker color from red, green and blue components. */
    public func marker_color(r: Float32, g: Float32, b: Float32): Line {
        let updated = unsafe { c_marker_color_rgb(this.ptr, r, g, b) }
        this.ptr = updated
        return this
    }

    /* Pass the marker-face flag to the native line. */
    public open func marker_face(v: Bool): Line {
        let updated = unsafe { c_marker_face(this.ptr, v) }
        this.ptr = updated
        return this
    }

    /* Set the marker face color from a string. */
    public open func marker_face_color(c: String): Line {
        let text = unsafe { LibC.mallocCString(c) }
        let updated = unsafe { c_marker_face_color_str(this.ptr, text) }
        this.ptr = updated
        unsafe { LibC.free(text) }
        return this
    }

    /* Set the marker face color from red, green and blue components. */
    public func marker_face_color(r: Float32, g: Float32, b: Float32): Line {
        let updated = unsafe { c_marker_face_color_rgb(this.ptr, r, g, b) }
        this.ptr = updated
        return this
    }

    /* Set the display name of the line.
       Unlike the other setters this returns Unit and does not update `ptr`. */
    public func display_name(display_name: String): Unit {
        let label = unsafe { LibC.mallocCString(display_name) }
        unsafe { c_display_name(this.ptr, label) }
        unsafe { LibC.free(label) }
    }
}

/* Plot y against x on the current axes and return a handle to the new line.
   x, y: Float64 coordinates; both vectors must have the same size.
   line_spec: optional line specification string (empty by default).
   Throws IllegalArgumentException when the sizes differ. */
public func plot(x: Vector<Float64>, y: Vector<Float64>, line_spec!:String = ""): Line {
    let count = x.size()
    if (count != y.size()) {
        throw IllegalArgumentException("plot: input are not vectors of the same size.")
    }
    // The C string only needs to live for the duration of the native call.
    let spec = unsafe { LibC.mallocCString(line_spec) }
    let raw = unsafe { c_plot(x.ptr, count, y.ptr, count, spec) }
    unsafe { LibC.free(spec) }
    return Line(raw)
}

/* Plot y against the implicit abscissa 1.0, 2.0, ..., n (n = y.size())
   on the current axes and return a handle to the new line. */
public func plot(y: Vector<Float64>, line_spec!:String = ""): Line {
    let count = y.size()
    // One-based positions: element k of the abscissa is k + 1.
    let positions = vector<Float64>(count, {k => Float64(k + 1)})
    return plot(positions, y, line_spec: line_spec)
}

/* Generic overload: converts both vectors with toFloat64() and forwards
   them, together with line_spec, to plot. */
public func plot<T>(x: Vector<T>, y: Vector<T>, line_spec!:String = ""): Line where T <: Real<T> {
    let xs = x.toFloat64()
    let ys = y.toFloat64()
    let line = plot(xs, ys, line_spec: line_spec)
    return line
}

/* Generic overload: converts y with toFloat64() and forwards it, together
   with line_spec, to plot. */
public func plot<T>(y: Vector<T>, line_spec!:String = ""): Line where T <: Real<T> {
    let ys = y.toFloat64()
    let line = plot(ys, line_spec: line_spec)
    return line
}

/* Plot y against x on the given axes and return a handle to the new line.
   x, y: Float64 coordinates; both vectors must have the same size.
   line_spec: optional line specification string (empty by default).
   Throws IllegalArgumentException when the sizes differ. */
public func plot(axes: AxesType, x: Vector<Float64>, y: Vector<Float64>, line_spec!:String = ""): Line {
    let count = x.size()
    if (count != y.size()) {
        throw IllegalArgumentException("plot: input are not vectors of the same size.")
    }
    // The C string only needs to live for the duration of the native call.
    let spec = unsafe { LibC.mallocCString(line_spec) }
    let raw = unsafe { c_axes_plot(axes.ptr, x.ptr, count, y.ptr, count, spec) }
    unsafe { LibC.free(spec) }
    return Line(raw)
}

/* Plot y against the implicit abscissa 1.0, 2.0, ..., n (n = y.size())
   on the given axes and return a handle to the new line. */
public func plot(axes: AxesType, y: Vector<Float64>, line_spec!:String = ""): Line {
    let count = y.size()
    // One-based positions: element k of the abscissa is k + 1.
    let positions = vector<Float64>(count, {k => Float64(k + 1)})
    return plot(axes, positions, y, line_spec: line_spec)
}

/* Generic overload for explicit axes: converts both vectors with
   toFloat64() and forwards them, together with axes and line_spec, to plot. */
public func plot<T>(axes: AxesType, x: Vector<T>, y: Vector<T>, line_spec!:String = ""): Line where T <: Real<T> {
    let xs = x.toFloat64()
    let ys = y.toFloat64()
    let line = plot(axes, xs, ys, line_spec: line_spec)
    return line
}

/* Generic overload for explicit axes: converts y with toFloat64() and
   forwards it, together with axes and line_spec, to plot. */
public func plot<T>(axes: AxesType, y: Vector<T>, line_spec!:String = ""): Line where T <: Real<T> {
    let ys = y.toFloat64()
    let line = plot(axes, ys, line_spec: line_spec)
    return line
}


/* Draws sin(x), -sin(x), a linear ramp and a short literal series on the
   same figure, each with its own line spec, then saves plot_1.svg and
   clears the figure. */
public func testPlot1() {
    print("    - testPlot1\n")
    let xs = linspace<Float64>(0.0, 2.0 * Float64.getPI())
    let sine = xs.apply({v => sin(v)})
    plot(xs, sine, line_spec: "-o")
    hold(true)

    let negated = sine.apply({v => -v})
    plot(xs, negated, line_spec: "--xr")

    let ramp = xs.apply({v => v / Float64.getPI() - 1.0})
    plot(xs, ramp, line_spec: "-:gs")

    let steps = vector([1.0, 0.7, 0.4, 0.0, -0.4, -0.7, -1.0])
    plot(steps, line_spec: "k")

    save("./tests/imgs/line_plot/plot_1.svg", "svg")
    clear()
}

/* Plots four literal series against their implicit index on one figure,
   then saves plot_2.svg and clears the figure. */
public func testPlot2() {
    print("    - testPlot2\n")
    let series = [
        vector([16.0, 5.0, 9.0, 4.0]),
        vector([2.0, 11.0, 7.0, 14.0]),
        vector([3.0, 10.0, 6.0, 15.0]),
        vector([13.0, 8.0, 12.0, 1.0])
    ]
    let first = series[0]
    plot(first)
    // Keep the first line while the remaining three are added.
    hold(true)
    let second = series[1]
    plot(second)
    let third = series[2]
    plot(third)
    let fourth = series[3]
    plot(fourth)
    save("./tests/imgs/line_plot/plot_2.svg", "svg")
    clear()
}

/* Plots sin(x), sin(x - 0.25) and sin(x - 0.5) over [0, 2*pi] with three
   different line specs, then saves plot_3.svg and clears the figure. */
public func testPlot3() {
    print("    - testPlot3\n")
    let xs = linspace<Float64>(0.0, 2.0 * Float64.getPI())
    let base = xs.apply({v => sin(v)})
    let shiftQuarter = xs.apply({v => sin(v - 0.25)})
    let shiftHalf = xs.apply({v => sin(v - 0.5)})

    plot(xs, base)
    hold(true)
    plot(xs, shiftQuarter, line_spec: "--")
    plot(xs, shiftHalf, line_spec: ":")

    save("./tests/imgs/line_plot/plot_3.svg", "svg")
    clear()
}

/* Same three shifted sine curves as testPlot3, drawn with the line specs
   "g", "b--o" and "c*", then saves plot_4.svg and clears the figure. */
public func testPlot4() {
    print("    - testPlot4\n")
    let xs = linspace<Float64>(0.0, 2.0 * Float64.getPI())
    let base = xs.apply({v => sin(v)})
    let shiftQuarter = xs.apply({v => sin(v - 0.25)})
    let shiftHalf = xs.apply({v => sin(v - 0.5)})

    plot(xs, base, line_spec: "g")
    hold(true)
    plot(xs, shiftQuarter, line_spec: "b--o")
    plot(xs, shiftHalf, line_spec: "c*")

    save("./tests/imgs/line_plot/plot_4.svg", "svg")
    clear()
}

/* Plots sin(x) over [0, 10] and passes every fifth index (0, 5, ..., 95)
   to marker_indices, then saves plot_5.svg and clears the figure. */
public func testPlot5() {
    print("    - testPlot5\n")
    let xs = linspace<Float64>(0.0, 10.0)
    let sine = xs.apply({v => sin(v)})
    let curve = plot(xs, sine, line_spec: "-o")
    let marked = vector<Int64>(
        [0, 5, 10, 15, 20, 25, 30, 35, 40, 45,
         50, 55, 60, 65, 70, 75, 80, 85, 90, 95])
    curve.marker_indices(marked)
    save("./tests/imgs/line_plot/plot_5.svg", "svg")
    clear()
}

/* Plots tan(sin(x)) - sin(tan(x)) on 20 points over [-pi, pi] and
   customises the line through chained setters (width, marker size,
   marker color, marker face color), then saves plot_6.svg and clears
   the figure. */
public func testPlot6() {
    print("    - testPlot6\n")
    let x = linspace<Float64>(-Float64.getPI(), Float64.getPI(), num:20)
    let y = x.apply({x => tan(sin(x)) - sin(tan(x))})
    // The chain's result is not needed afterwards, so it is not bound to a
    // local (the previous `var line` was never read).
    plot(x, y, line_spec:"--gs").line_width(2.0)
        .marker_size(10.0).marker_color("b")
        .marker_face_color(0.5, 0.5, 0.5)
    save("./tests/imgs/line_plot/plot_6.svg", "svg")
    clear()
}

/* Plots cos(5x) on 150 points over [0, 10] with an RGB line color, adds a
   title and axis labels, then saves plot_7.svg and clears the figure. */
public func testPlot7() {
    print("    - testPlot7\n")
    let xs = linspace<Float64>(0.0, 10.0, num: 150)
    let wave = xs.apply({v => cos(5.0 * v)})
    let curve = plot(xs, wave)
    curve.color(0.0, 0.7, 0.9)
    title("2-D Line Plot")
    xlabel("x")
    ylabel("cos(5x)")
    save("./tests/imgs/line_plot/plot_7.svg", "svg")
    clear()
}

/* Plots seven samples against 0..180 in steps of 30, fixes the y range to
   [0, 1] and replaces the x ticks with time-style labels, then saves
   plot_8.svg and clears the figure. */
public func testPlot8() {
    print("    - testPlot8\n")
    let times = linspace<Float64>(0.0, 180.0, num: 7)
    let samples = vector<Float64>([0.8, 0.9, 0.1, 0.9, 0.6, 0.1, 0.3])
    plot(times, samples)
    title("Time Plot")
    xlabel("Time")
    yrange(0.0, 1.0)
    // One tick per sample position, each with its own label.
    let tickPositions = vector<Float64>([0.0, 30.0, 60.0, 90.0, 120.0, 150.0, 180.0])
    xticks(tickPositions)
    xticklabels(["00:00s", "00:30", "01:00", "01:30", "02:00", "02:30", "03:00"])
    save("./tests/imgs/line_plot/plot_8.svg", "svg")
    clear()
}

/* Builds a 2x1 tiled layout and plots sin(5x) in the first tile and
   sin(15x) in the second, each with its own title and y label, then saves
   plot_9.svg and clears the figure. */
public func testPlot9() {
    print("    - testPlot9\n")
    let xs = linspace<Float64>(0.0, 3.0)
    let slow = xs.apply({v => sin(5.0 * v)})
    let fast = xs.apply({v => sin(15.0 * v)})

    tiledlayout(2, 1)

    let top = nexttile()
    plot(top, xs, slow)
    title(top, "Top Plot")
    ylabel(top, "sin(5x)")

    let bottom = nexttile()
    plot(bottom, xs, fast)
    title(bottom, "Bottom Plot")
    ylabel(bottom, "sin(15x)")

    save("./tests/imgs/line_plot/plot_9.svg", "svg")
    clear()
}

/* Plots sin and cos on one figure and then restyles the two lines through
   their handles (width on the first, asterisk marker on the second),
   then saves plot_10.svg and clears the figure.
   NOTE(review): the x range runs from -2*pi to 3.0, which is asymmetric —
   confirm the upper bound is intended. */
public func testPlot10() {
    print("    - testPlot10\n")
    let xs = linspace<Float64>(-2.0 * Float64.getPI(), 3.0)
    let sine = xs.apply({v => sin(v)})
    let cosine = xs.apply({v => cos(v)})

    let sineLine = plot(xs, sine)
    hold(true)
    let cosineLine = plot(xs, cosine)

    // Styling is applied after both lines exist, via the returned handles.
    sineLine.line_width(2.0)
    cosineLine.marker(MarkerStyle.Asterisk)

    save("./tests/imgs/line_plot/plot_10.svg", "svg")
    clear()
}

/* Plots a circle of radius 2 centred at (4, 3), parameterised by an angle
   over [0, 2*pi], with the Equal axis style, then saves plot_11.svg and
   clears the figure. */
public func testPlot11() {
    print("    - testPlot11\n")
    let radius: Float64 = 2.0
    let centerX: Float64 = 4.0
    let centerY: Float64 = 3.0
    let angles = linspace<Float64>(0.0, 2.0 * Float64.getPI())
    let xs = angles.apply({a => radius * cos(a) + centerX})
    let ys = angles.apply({a => radius * sin(a) + centerY})
    plot(xs, ys)
    axis(AxisStyle.Equal)
    save("./tests/imgs/line_plot/plot_11.svg", "svg")
    clear()
}

/* Plots an Int32 series against its implicit index (a non-Float64
   element type), then saves plot_12.svg and clears the figure. */
public func testPlot12() {
    print("    - testPlot12\n")
    let samples = vector<Int32>([2, 4, 7, 7, 6, 3, 9, 7, 3, 5])
    plot(samples)
    save("./tests/imgs/line_plot/plot_12.svg", "svg")
    clear()
}

/* Runs every line-plot test in this file in numeric order.
   Prints a group header first; each testPlotN prints its own name, saves
   its figure under ./tests/imgs/line_plot/ and clears the figure. */
public func testPlot() {
    print("  + testPlot\n")
    testPlot1()
    testPlot2()
    testPlot3()
    testPlot4()
    testPlot5()
    testPlot6()
    testPlot7()
    testPlot8()
    testPlot9()
    testPlot10()
    testPlot11()
    testPlot12()
}