package scientific.matplot

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.random.*
import scientific.stats.normal.*

/*
 * Native entry points for line contours.
 * x, y and z are raw pointers to Float64 data; the Cangjie wrappers below pass
 * `Matrix<Float64>.ptr` for them, with row/col taken from x.getRows()/x.getCols().
 * NOTE(review): assumes each buffer holds row * col values in the layout the
 * C side expects (TODO confirm element order against the C implementation).
 * `levels`, where present, is a Float64 buffer with `n_levels` entries.
 * Each function returns an opaque handle that the wrappers store in `Contours`.
 */
// No level information is passed; the native side decides the levels.
foreign func c_contour(
    x: CPointer<Unit>, y: CPointer<Unit>, z: CPointer<Unit>,
    row: Int64, col: Int64, line_spec: CString): CPointer<Unit>

// Requests `n_levels` contour levels.
foreign func c_contour_nlevels(
    x: CPointer<Unit>, y: CPointer<Unit>, z: CPointer<Unit>,
    row: Int64, col: Int64, n_levels: Int64, line_spec: CString): CPointer<Unit>

// Draws contours at the explicit `levels` values (n_levels of them).
foreign func c_contour_levels(
    x: CPointer<Unit>, y: CPointer<Unit>, z: CPointer<Unit>,
    row: Int64, col: Int64, levels: CPointer<Unit>, n_levels: Int64,
    line_spec: CString): CPointer<Unit>

/*
 * Native entry points for filled contours. The argument conventions mirror the
 * line-contour functions: x, y, z (and levels, when present) are Float64 buffers,
 * row/col give the grid dimensions, and the return value is an opaque handle.
 * NOTE(review): buffer layout is assumed to match Matrix<Float64>.ptr (TODO confirm).
 */
// No level information is passed; the native side decides the levels.
foreign func c_contourf(
    x: CPointer<Unit>, y: CPointer<Unit>, z: CPointer<Unit>,
    row: Int64, col: Int64, line_spec: CString): CPointer<Unit>

// Requests `n_levels` filled bands.
foreign func c_contourf_nlevels(
    x: CPointer<Unit>, y: CPointer<Unit>, z: CPointer<Unit>,
    row: Int64, col: Int64, n_levels: Int64, line_spec: CString): CPointer<Unit>

// Uses the explicit `levels` values (n_levels of them) as band boundaries.
foreign func c_contourf_levels(
    x: CPointer<Unit>, y: CPointer<Unit>, z: CPointer<Unit>,
    row: Int64, col: Int64, levels: CPointer<Unit>, n_levels: Int64,
    line_spec: CString): CPointer<Unit>

/*
 * Setters applied to an existing contour handle. Each takes the current handle
 * and returns a handle, which the `Contours` wrapper stores in place of the old one.
 * `levels` is a Float64 buffer with `n_levels` entries; string arguments are
 * NUL-terminated C strings owned (and freed) by the caller.
 */
foreign func c_contour_text(contour: CPointer<Unit>, b: Bool): CPointer<Unit>
foreign func c_contour_line_width(contour: CPointer<Unit>, width: Float32): CPointer<Unit>
foreign func c_contour_line_style(contour: CPointer<Unit>, style: CString): CPointer<Unit>
foreign func c_contour_font_size(contour: CPointer<Unit>, font_size: Float32): CPointer<Unit>
foreign func c_contour_font_color(contour: CPointer<Unit>, color: CString): CPointer<Unit>
foreign func c_contour_font_weight(contour: CPointer<Unit>, weight: CString): CPointer<Unit>
foreign func c_contour_set_levels(contour: CPointer<Unit>, levels: CPointer<Unit>, n_levels: Int64): CPointer<Unit>
foreign func c_contour_set_nlevels(contour: CPointer<Unit>, n_levels: Int64): CPointer<Unit>
foreign func c_contour_filled(contour: CPointer<Unit>, filled: Bool): CPointer<Unit>
foreign func c_contour_colormap_line_when_filled(contour: CPointer<Unit>, b: Bool): CPointer<Unit>

/*
 * Handle to a native contour object (a pointer to matplot::contours).
 *
 * Every setter forwards to the matching `c_contour_*` foreign function, replaces
 * `ptr` with the handle that function returns, and returns `this`, so settings
 * can be chained:
 *     contour(X, Y, Z).line_width(2.0).line_style("--")
 */
public class Contours {
    // Current native handle; overwritten by the return value of every setter.
    var ptr: CPointer<Unit> = CPointer<Unit>()

    /* Wraps `ptr`, which must point to a matplot::contours object. */
    init(ptr: CPointer<Unit>) {
        this.ptr = ptr
    }

    /* Sets the contour-text flag (forwarded to c_contour_text). Returns this. */
    public func contour_text(b: Bool): Contours {
        this.ptr = unsafe { c_contour_text(this.ptr, b) }
        return this
    }

    /* Sets the contour line width (forwarded to c_contour_line_width). Returns this. */
    public func line_width(width: Float32): Contours {
        this.ptr = unsafe { c_contour_line_width(this.ptr, width) }
        return this
    }

    /*
     * Sets the line style string, e.g. "--" (forwarded to c_contour_line_style).
     * A temporary C copy of `style` is allocated for the call and freed right after.
     * NOTE(review): this assumes the native side copies the string rather than
     * keeping the pointer; confirm against the C implementation.
     * Returns this.
     */
    public func line_style(style: String): Contours {
        let cstr_style = unsafe { LibC.mallocCString(style) }
        this.ptr = unsafe { c_contour_line_style(this.ptr, cstr_style) }
        unsafe { LibC.free(cstr_style) }
        return this
    }

    /* Sets the label font size (forwarded to c_contour_font_size). Returns this. */
    public func font_size(font_size: Float32): Contours {
        this.ptr = unsafe { c_contour_font_size(this.ptr, font_size) }
        return this
    }

    /*
     * Sets the label font color by name, e.g. "blue" (forwarded to c_contour_font_color).
     * The temporary C string is freed once the native call returns. Returns this.
     */
    public func font_color(color: String): Contours {
        let cstr_color = unsafe { LibC.mallocCString(color) }
        this.ptr = unsafe { c_contour_font_color(this.ptr, cstr_color) }
        unsafe { LibC.free(cstr_color) }
        return this
    }

    /*
     * Sets the label font weight, e.g. "bold" (forwarded to c_contour_font_weight).
     * The temporary C string is freed once the native call returns. Returns this.
     */
    public func font_weight(weight: String): Contours {
        let cstr_weight = unsafe { LibC.mallocCString(weight) }
        this.ptr = unsafe { c_contour_font_weight(this.ptr, cstr_weight) }
        unsafe { LibC.free(cstr_weight) }
        return this
    }

    /* Replaces the contour levels with the values in `levels`. Returns this. */
    public func levels(levels: Vector<Float64>): Contours {
        this.ptr = unsafe { c_contour_set_levels(this.ptr, levels.ptr, levels.size()) }
        return this
    }

    /* Requests `n_levels` automatically placed levels. Returns this. */
    public func n_levels(n_levels: Int64): Contours {
        this.ptr = unsafe { c_contour_set_nlevels(this.ptr, n_levels) }
        return this
    }

    /* Switches between line and filled rendering (c_contour_filled). Returns this. */
    public func filled(filled: Bool): Contours {
        this.ptr = unsafe { c_contour_filled(this.ptr, filled) }
        return this
    }

    /* Sets the "colormap line when filled" flag (c_contour_colormap_line_when_filled). Returns this. */
    public func colormap_line_when_filled(b: Bool): Contours {
        this.ptr = unsafe { c_contour_colormap_line_when_filled(this.ptr, b) }
        return this
    }
}

/*
 * Draws line contours of z over the grid (x, y). No level information is passed,
 * so level selection is left to the native side. `line_spec` is forwarded as a
 * temporary C string that is freed after the call.
 * Throws IllegalArgumentException when x, y and z do not all share one shape.
 * Returns the resulting Contours handle for further configuration.
 */
public func contour(x: Matrix<Float64>, y: Matrix<Float64>, z: Matrix<Float64>,
                    line_spec!:String = ""): Contours {
    let expected = x.shape()
    if (y.shape() != expected || z.shape() != expected) {
        throw IllegalArgumentException("contour: dimension mismatch")
    }
    let spec = unsafe { LibC.mallocCString(line_spec) }
    let rawHandle = unsafe {
        c_contour(x.ptr, y.ptr, z.ptr, x.getRows(), x.getCols(), spec)
    }
    unsafe { LibC.free(spec) }
    return Contours(rawHandle)
}

/*
 * Draws line contours of z over the grid (x, y) using `n_levels` levels.
 * `line_spec` is forwarded as a temporary C string that is freed after the call.
 * Throws IllegalArgumentException when x, y and z do not all share one shape.
 * Returns the resulting Contours handle for further configuration.
 */
public func contour(x: Matrix<Float64>, y: Matrix<Float64>, z: Matrix<Float64>,
                    n_levels: Int64, line_spec!:String = ""): Contours {
    let expected = x.shape()
    if (y.shape() != expected || z.shape() != expected) {
        throw IllegalArgumentException("contour: dimension mismatch")
    }
    let spec = unsafe { LibC.mallocCString(line_spec) }
    let rawHandle = unsafe {
        c_contour_nlevels(x.ptr, y.ptr, z.ptr, x.getRows(), x.getCols(), n_levels, spec)
    }
    unsafe { LibC.free(spec) }
    return Contours(rawHandle)
}

/*
 * Draws line contours of z over the grid (x, y) at the explicit values in `levels`.
 * `line_spec` is forwarded as a temporary C string that is freed after the call.
 * Throws IllegalArgumentException when x, y and z do not all share one shape.
 * Returns the resulting Contours handle for further configuration.
 */
public func contour(x: Matrix<Float64>, y: Matrix<Float64>, z: Matrix<Float64>,
                    levels: Vector<Float64>, line_spec!:String = ""): Contours {
    let expected = x.shape()
    if (y.shape() != expected || z.shape() != expected) {
        throw IllegalArgumentException("contour: dimension mismatch")
    }
    let levelCount = levels.size()
    let spec = unsafe { LibC.mallocCString(line_spec) }
    let rawHandle = unsafe {
        c_contour_levels(x.ptr, y.ptr, z.ptr, x.getRows(), x.getCols(),
                         levels.ptr, levelCount, spec)
    }
    unsafe { LibC.free(spec) }
    return Contours(rawHandle)
}

/*
 * Draws filled contours of z over the grid (x, y). No level information is
 * passed, so level selection is left to the native side. `line_spec` is
 * forwarded as a temporary C string that is freed after the call.
 * Throws IllegalArgumentException when x, y and z do not all share one shape.
 * Returns the resulting Contours handle for further configuration.
 */
public func contourf(x: Matrix<Float64>, y: Matrix<Float64>, z: Matrix<Float64>,
                     line_spec!:String = ""): Contours {
    if (x.shape() != y.shape() || x.shape() != z.shape()) {
        // Report the function actually called, not its line-contour sibling.
        throw IllegalArgumentException("contourf: dimension mismatch")
    }
    let row = x.getRows()
    let col = x.getCols()
    let cstr_line_spec = unsafe { LibC.mallocCString(line_spec) }
    let handle = unsafe { c_contourf(
        x.ptr, y.ptr, z.ptr, row, col, cstr_line_spec) }
    unsafe { LibC.free(cstr_line_spec) }
    return Contours(handle)
}

/*
 * Draws filled contours of z over the grid (x, y) using `n_levels` levels.
 * `line_spec` is forwarded as a temporary C string that is freed after the call.
 * Throws IllegalArgumentException when x, y and z do not all share one shape.
 * Returns the resulting Contours handle for further configuration.
 */
public func contourf(x: Matrix<Float64>, y: Matrix<Float64>, z: Matrix<Float64>,
                     n_levels: Int64, line_spec!:String = ""): Contours {
    if (x.shape() != y.shape() || x.shape() != z.shape()) {
        // Report the function actually called, not its line-contour sibling.
        throw IllegalArgumentException("contourf: dimension mismatch")
    }
    let row = x.getRows()
    let col = x.getCols()
    let cstr_line_spec = unsafe { LibC.mallocCString(line_spec) }
    let handle = unsafe { c_contourf_nlevels(
        x.ptr, y.ptr, z.ptr, row, col, n_levels, cstr_line_spec) }
    unsafe { LibC.free(cstr_line_spec) }
    return Contours(handle)
}

/*
 * Draws filled contours of z over the grid (x, y) with band boundaries at the
 * values in `levels`. `line_spec` is forwarded as a temporary C string that is
 * freed after the call.
 * Throws IllegalArgumentException when x, y and z do not all share one shape.
 * Returns the resulting Contours handle for further configuration.
 */
public func contourf(x: Matrix<Float64>, y: Matrix<Float64>, z: Matrix<Float64>,
                     levels: Vector<Float64>, line_spec!:String = ""): Contours {
    if (x.shape() != y.shape() || x.shape() != z.shape()) {
        // Report the function actually called, not its line-contour sibling.
        throw IllegalArgumentException("contourf: dimension mismatch")
    }
    let row = x.getRows()
    let col = x.getCols()
    let n_levels = levels.size()
    let cstr_line_spec = unsafe { LibC.mallocCString(line_spec) }
    let handle = unsafe { c_contourf_levels(
        x.ptr, y.ptr, z.ptr, row, col, levels.ptr, n_levels, cstr_line_spec) }
    unsafe { LibC.free(cstr_line_spec) }
    return Contours(handle)
}

/*
 * Samples `f` on a meshgrid built from linspace(xmin, xmax) and linspace(ymin, ymax)
 * (linspace's default point count on each axis) and draws its line contours via
 * `contour`. The return type is stated explicitly so this public API does not
 * rely on inference.
 */
public func fcontour(
    f: (Float64, Float64) -> Float64,
    xmin!:Float64 = -5.0, xmax!:Float64 = 5.0, ymin!:Float64 = -5.0, ymax!:Float64 = 5.0,
    line_spec!:String = ""): Contours {
    let x = linspace(xmin, xmax)
    let y = linspace(ymin, ymax)
    let (X, Y, Z) = meshgrid(x, y, f)
    return contour(X, Y, Z, line_spec:line_spec)
}

/*
 * Samples `f` on a meshgrid built from linspace(xmin, xmax) and linspace(ymin, ymax)
 * and draws `n_levels` line contours via `contour`. `n_levels` is a required
 * named argument, e.g. fcontour(f, n_levels: 25).
 * The return type is stated explicitly so this public API does not rely on inference.
 */
public func fcontour(
    f: (Float64, Float64) -> Float64,
    xmin!:Float64 = -5.0, xmax!:Float64 = 5.0, ymin!:Float64 = -5.0, ymax!:Float64 = 5.0,
    n_levels!:Int64, line_spec!:String = ""): Contours {
    let x = linspace(xmin, xmax)
    let y = linspace(ymin, ymax)
    let (X, Y, Z) = meshgrid(x, y, f)
    return contour(X, Y, Z, n_levels, line_spec:line_spec)
}

/*
 * Samples `f` on a meshgrid built from linspace(xmin, xmax) and linspace(ymin, ymax)
 * and draws line contours at the explicit values in `levels` via `contour`.
 * `levels` is a required named argument, e.g. fcontour(f, levels: linspace(-2.0, 2.0, num: 9)).
 * The return type is stated explicitly so this public API does not rely on inference.
 */
public func fcontour(
    f: (Float64, Float64) -> Float64,
    xmin!:Float64 = -5.0, xmax!:Float64 = 5.0, ymin!:Float64 = -5.0, ymax!:Float64 = 5.0,
    levels!:Vector<Float64>, line_spec!:String = ""): Contours {
    let x = linspace(xmin, xmax)
    let y = linspace(ymin, ymax)
    let (X, Y, Z) = meshgrid(x, y, f)
    return contour(X, Y, Z, levels, line_spec:line_spec)
}

/* contour_1: sin(u) + cos(v) over [-2pi, 2pi] x [0, 4pi] with default levels. */
public func testContour1() {
    let piValue = Float64.getPI()
    let xs = linspace(-2.0 * piValue, 2.0 * piValue)
    let ys = linspace(0.0, 4.0 * piValue)
    let surface = {u: Float64, v: Float64 => sin(u) + cos(v)}
    let (gx, gy, gz) = meshgrid(xs, ys, surface)
    contour(gx, gy, gz)
    save("./tests/imgs/contour/contour_1.svg", "svg")
    clear()
}

/* contour_2: peaks sampled with 50 points per axis on [-3, 3], 20 levels. */
public func testContour2() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    let levelCount = 20
    contour(gx, gy, gz, levelCount)
    save("./tests/imgs/contour/contour_2.svg", "svg")
    clear()
}

/* contour_3: peaks with a single explicit level at 1.0. */
public func testContour3() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    let chosenLevels = vector([1.0])
    contour(gx, gy, gz, chosenLevels)
    save("./tests/imgs/contour/contour_3.svg", "svg")
    clear()
}

/* contour_4: peaks drawn with the dashed line spec "--". */
public func testContour4() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    contour(gx, gy, gz, line_spec:"--")
    save("./tests/imgs/contour/contour_4.svg", "svg")
    clear()
}

/* contour_5: u * exp(-u^2 - v^2) on a 46-point grid, with contour text enabled. */
public func testContour5() {
    let us = linspace(-3.0, 3.0, num: 46)
    let vs = linspace(-3.0, 3.0, num: 46)
    let bump = {u: Float64, v: Float64 => u * exp(-pow(u, 2.0) - pow(v, 2.0))}
    let (gx, gy, gz) = meshgrid(us, vs, bump)
    let labelled = contour(gx, gy, gz)
    labelled.contour_text(true)
    save("./tests/imgs/contour/contour_5.svg", "svg")
    clear()
}

/* contour_6: peaks drawn with line width 3. */
public func testContour6() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    let thick = contour(gx, gy, gz)
    thick.line_width(3.0)
    save("./tests/imgs/contour/contour_6.svg", "svg")
    clear()
}

/* contour_7: peaks with labels styled as 15pt, blue, bold text. */
public func testContour7() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    // Every setter returns the same object, so chaining applies them in order to one handle.
    let labelled = contour(gx, gy, gz).contour_text(true)
    labelled.font_size(15.0).font_color("blue").font_weight("bold")
    save("./tests/imgs/contour/contour_7.svg", "svg")
    clear()
}

/*
 * contour_8: peaks on a 49-point grid with column 25 of Z replaced by NaN,
 * drawn with contour text enabled. The NaN loop follows Z's actual row
 * count instead of repeating the grid size as a literal.
 */
public func testContour8() {
    let x = linspace(-3.0, 3.0, num: 49)
    let (X, Y, Z) = meshgrid(x, x, peaks)
    for (i in 0..Z.getRows()) {
        Z[i,25] = Float64.NaN
    }
    contour(X, Y, Z).contour_text(true)
    save("./tests/imgs/contour/contour_8.svg", "svg")
    clear()
}

/* contourf_1: filled contours of peaks with default levels. */
public func testFilledContour1() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    contourf(gx, gy, gz)
    save("./tests/imgs/contourf/contourf_1.svg", "svg")
    clear()
}

/* contourf_2: sin(u) + cos(v) over [-2pi, 2pi] x [0, 4pi] with 10 filled levels. */
public func testFilledContour2() {
    let piValue = Float64.getPI()
    let us = linspace(-2.0 * piValue, 2.0 * piValue, num: 50)
    let vs = linspace(0.0, 4.0 * piValue, num: 50)
    let wave = {u: Float64, v: Float64 => sin(u) + cos(v)}
    let (gx, gy, gz) = meshgrid(us, vs, wave)
    let bandCount = 10
    contourf(gx, gy, gz, bandCount)
    save("./tests/imgs/contourf/contourf_2.svg", "svg")
    clear()
}

/* contourf_3: filled peaks with levels {2, 3} and contour text enabled. */
public func testFilledContour3() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    let bounds = vector([2.0, 3.0])
    let shaded = contourf(gx, gy, gz, bounds)
    shaded.contour_text(true)
    save("./tests/imgs/contourf/contourf_3.svg", "svg")
    clear()
}

/* contourf_4: filled peaks with the single level 2. */
public func testFilledContour4() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    let bounds = vector([2.0])
    contourf(gx, gy, gz, bounds)
    save("./tests/imgs/contourf/contourf_4.svg", "svg")
    clear()
}

/* contourf_5: filled peaks drawn with the dashed line spec "--". */
public func testFilledContour5() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    contourf(gx, gy, gz, line_spec:"--")
    save("./tests/imgs/contourf/contourf_5.svg", "svg")
    clear()
}

/* contourf_6: filled peaks with line width 3. */
public func testFilledContour6() {
    let samples = linspace(-3.0, 3.0, num: 50)
    let (gx, gy, gz) = meshgrid(samples, samples, peaks)
    let shaded = contourf(gx, gy, gz)
    shaded.line_width(3.0)
    save("./tests/imgs/contourf/contourf_6.svg", "svg")
    clear()
}

/* fcontour_1: sin(u) + cos(v) over fcontour's default [-5, 5] x [-5, 5] window. */
public func testFunctionContour1() {
    let wave = {u: Float64, v: Float64 => sin(u) + cos(v)}
    fcontour(wave)
    save("./tests/imgs/fcontour/fcontour_1.svg", "svg")
    clear()
}

/*
 * fcontour_2: two functions side by side with 9 shared level values in [-2, 2]:
 * erf(a) + cos(b) on the left half and sin(a) + cos(b) on the right half.
 */
public func testFunctionContour2() {
    fcontour({a, b => erf(a) + cos(b)},
             xmin:-5.0, xmax:0.0, ymin:-5.0, ymax:5.0,
             levels: linspace(-2.0, 2.0, num: 9))
    hold(true)
    fcontour({a, b => sin(a) + cos(b)},
             xmin:0.0, xmax:5.0, ymin:-5.0, ymax:5.0,
             levels: linspace(-2.0, 2.0, num: 9))
    hold(false)
    axis(-5.0, 5.0, -5.0, 5.0)
    grid(true)
    save("./tests/imgs/fcontour/fcontour_2.svg", "svg")
    clear()
}

/* fcontour_3: saddle u^2 - v^2 drawn dashed with line width 2. */
public func testFunctionContour3() {
    let saddle = {u: Float64, v: Float64 => pow(u, 2.0) - pow(v, 2.0)}
    let dashed = fcontour(saddle, line_spec:"--")
    dashed.line_width(2.0)
    save("./tests/imgs/fcontour/fcontour_3.svg", "svg")
    clear()
}

/* fcontour_4: sin(a) + cos(b) overlaid with a - b on the same axes. */
public func testFunctionContour4() {
    fcontour({a, b => sin(a) + cos(b)})
    hold(true)
    fcontour({a, b => a - b})
    hold(false)
    save("./tests/imgs/fcontour/fcontour_4.svg", "svg")
    clear()
}

/* fcontour_5: two Gaussian bumps, dashed lines of width 1, explicit levels. */
public func testFunctionContour5() {
    let twoBumps = {u: Float64, v: Float64 =>
        exp(-pow(u / 3.0, 2.0) - pow(v / 3.0, 2.0)) +
        exp(-pow(u + 2.0, 2.0) - pow(v + 2.0, 2.0))
    }
    let styled = fcontour(twoBumps).line_width(1.0).line_style("--")
    styled.levels(vector([1.0, 0.9, 0.8, 0.2, 0.1]))
    save("./tests/imgs/fcontour/fcontour_5.svg", "svg")
    clear()
}

/* fcontour_6: erf/exp surface, filled, with colormapped contour lines. */
public func testFunctionContour6() {
    let surface = {u: Float64, v: Float64 =>
        erf(pow(v + 2.0, 3.0)) -
        exp(-0.65 * (pow(u - 2.0, 2.0) + pow(v - 2.0, 2.0)))
    }
    let shaded = fcontour(surface).filled(true)
    shaded.colormap_line_when_filled(true)
    save("./tests/imgs/fcontour/fcontour_6.svg", "svg")
    clear()
}

/* fcontour_7: same surface as fcontour_6 but with 25 levels. */
public func testFunctionContour7() {
    let surface = {u: Float64, v: Float64 =>
        erf(pow(v + 2.0, 3.0)) -
        exp(-0.65 * (pow(u - 2.0, 2.0) + pow(v - 2.0, 2.0)))
    }
    let shaded = fcontour(surface, n_levels:25).filled(true)
    shaded.colormap_line_when_filled(true)
    save("./tests/imgs/fcontour/fcontour_7.svg", "svg")
    clear()
}

/* fcontour_8: sin(a) + cos(b) with levels replaced by {-1, 0, 1} after creation. */
public func testFunctionContour8() {
    let plotted = fcontour({a, b => sin(a) + cos(b)})
    plotted.levels(vector([-1.0, 0.0, 1.0]))
    save("./tests/imgs/fcontour/fcontour_8.svg", "svg")
    clear()
}

/* fcontour_9: Rastrigin function, filled. */
public func testFunctionContour9() {
    let twoPi = 2.0 * Float64.getPI()
    let rastrigin = {u: Float64, v: Float64 =>
        20.0 + pow(u, 2.0) - 10.0 * cos(twoPi * u) + pow(v, 2.0) -
        10.0 * cos(twoPi * v)
    }
    let plotted = fcontour(rastrigin)
    plotted.filled(true)
    save("./tests/imgs/fcontour/fcontour_9.svg", "svg")
    clear()
}

/* fcontour_10: Ackley function with 10 levels, filled. */
public func testFunctionContour10() {
    let twoPi = 2.0 * Float64.getPI()
    let ackley = {u: Float64, v: Float64 =>
        -20.0 * exp(-0.2 * sqrt(0.5 * (pow(u, 2.0) + pow(v, 2.0)))) -
        exp(0.5 * (cos(twoPi * u) + cos(twoPi * v))) + exp(1.0) + 20.0
    }
    let plotted = fcontour(ackley).n_levels(10)
    plotted.filled(true)
    save("./tests/imgs/fcontour/fcontour_10.svg", "svg")
    clear()
}

/* fcontour_11: Rosenbrock function with 10 levels, filled. */
public func testFunctionContour11() {
    let rosenbrock = {u: Float64, v: Float64 =>
        100.0 * pow(v - pow(u, 2.0), 2.0) + pow(1.0 - u, 2.0)
    }
    let plotted = fcontour(rosenbrock).n_levels(10)
    plotted.filled(true)
    save("./tests/imgs/fcontour/fcontour_11.svg", "svg")
    clear()
}

/*
 * Runs testContour1 through testContour8 in order; each one writes a single
 * SVG under ./tests/imgs/contour/ and then calls clear().
 */
public func testContour() {
    testContour1()
    testContour2()
    testContour3()
    testContour4()
    testContour5()
    testContour6()
    testContour7()
    testContour8()
}

/*
 * Runs testFilledContour1 through testFilledContour6 in order; each one writes
 * a single SVG under ./tests/imgs/contourf/ and then calls clear().
 */
public func testFilledContour() {
    testFilledContour1()
    testFilledContour2()
    testFilledContour3()
    testFilledContour4()
    testFilledContour5()
    testFilledContour6()
}

/*
 * Runs testFunctionContour1 through testFunctionContour11 in order; each one
 * writes a single SVG under ./tests/imgs/fcontour/ and then calls clear().
 */
public func testFunctionContour() {
    testFunctionContour1()
    testFunctionContour2()
    testFunctionContour3()
    testFunctionContour4()
    testFunctionContour5()
    testFunctionContour6()
    testFunctionContour7()
    testFunctionContour8()
    testFunctionContour9()
    testFunctionContour10()
    testFunctionContour11()
}