package scientific.matplot

import std.math.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.numbers.*
import scientific.linear.*
import scientific.stats.random.*
import scientific.stats.normal.*

/*
 * Native entry points of the C plotting shim used by the public scatter3
 * overloads in this file. Every data pointer (x, y, z, size, color) refers to
 * a buffer of Float64 values; the wrappers pass `Vector<Float64>.ptr`.
 * `line_spec` is a NUL-terminated string the caller allocates with
 * `LibC.mallocCString`. NOTE(review): the caller keeps ownership and must
 * free it; the native side presumably copies it rather than retaining the
 * pointer — confirm against the C shim.
 * Each function returns an opaque handle (CPointer<Unit>) to the created plot
 * object, which the Cangjie wrappers wrap in a `Line`.
 */
// Takes a separate length for each coordinate buffer.
foreign func c_scatter3(
    x: CPointer<Unit>, x_len: Int64,
    y: CPointer<Unit>, y_len: Int64,
    z: CPointer<Unit>, z_len: Int64,
    line_spec: CString): CPointer<Unit>

// Takes one shared length `len` for x, y, z and size; presumably reads `len`
// elements from every buffer, so all four must hold at least `len` values.
foreign func c_scatter3_with_size(
    x: CPointer<Unit>, y: CPointer<Unit>,
    z: CPointer<Unit>, size: CPointer<Unit>,
    len: Int64, line_spec: CString): CPointer<Unit>

// Like c_scatter3_with_size, plus a per-point color buffer; the single `len`
// presumably applies to all five buffers.
foreign func c_scatter3_with_size_color(
    x: CPointer<Unit>, y: CPointer<Unit>, z: CPointer<Unit>,
    size: CPointer<Unit>, color: CPointer<Unit>, len: Int64, line_spec: CString): CPointer<Unit>


/**
 * Draws a 3-D scatter plot of the points (x[i], y[i], z[i]).
 *
 * @param x x coordinates.
 * @param y y coordinates.
 * @param z z coordinates.
 * @param line_spec marker/line specification forwarded verbatim to the native
 *        plotting call (the tests use e.g. "filled" and "*").
 * @return a `Line` wrapping the native handle of the created plot object.
 */
public func scatter3(x: Vector<Float64>, y: Vector<Float64>, z: Vector<Float64>,
                     line_spec!:String = ""): Line {
    let cstr_line_spec = unsafe { LibC.mallocCString(line_spec) }
    // The native call receives one length per buffer; pass each vector's own
    // size so a shorter y or z is never read past its end.
    let handle = unsafe {
        c_scatter3(x.ptr, x.size(), y.ptr, y.size(), z.ptr, z.size(), cstr_line_spec)
    }
    // The temporary C string is only needed for the duration of the call.
    unsafe { LibC.free(cstr_line_spec) }
    return Line(handle)
}

/**
 * Draws a 3-D scatter plot of the points (x[i], y[i], z[i]) with a per-point
 * marker size.
 *
 * @param x x coordinates.
 * @param y y coordinates.
 * @param z z coordinates.
 * @param sz marker size for each point.
 * @param line_spec marker/line specification forwarded verbatim to the native
 *        plotting call.
 * @return a `Line` wrapping the native handle of the created plot object.
 * @throws IllegalArgumentException if x, y, z and sz do not all have the same length.
 */
public func scatter3(x: Vector<Float64>, y: Vector<Float64>, z: Vector<Float64>, sz: Vector<Float64>,
                     line_spec!:String = ""): Line {
    let size = x.size()
    // The native call takes a single length for all four buffers, so they must
    // agree or the shorter ones would be read out of bounds. Checked before the
    // C string is allocated so a failure leaks nothing.
    if (y.size() != size || z.size() != size || sz.size() != size) {
        throw IllegalArgumentException("scatter3: x, y, z and sz must have the same length")
    }
    let cstr_line_spec = unsafe { LibC.mallocCString(line_spec) }
    let handle = unsafe { c_scatter3_with_size(x.ptr, y.ptr, z.ptr, sz.ptr, size, cstr_line_spec) }
    // Release the temporary C string allocated above; without this it leaks on every call.
    unsafe { LibC.free(cstr_line_spec) }
    return Line(handle)
}

/**
 * Draws a 3-D scatter plot of the points (x[i], y[i], z[i]) with a per-point
 * marker size and color value.
 *
 * @param x x coordinates.
 * @param y y coordinates.
 * @param z z coordinates.
 * @param sz marker size for each point.
 * @param color color value for each point, passed through to the native call.
 * @param line_spec marker/line specification forwarded verbatim to the native
 *        plotting call.
 * @return a `Line` wrapping the native handle of the created plot object.
 * @throws IllegalArgumentException if x, y, z, sz and color do not all have the same length.
 */
public func scatter3(x: Vector<Float64>, y: Vector<Float64>, z: Vector<Float64>, sz: Vector<Float64>,
                     color: Vector<Float64>, line_spec!:String = ""): Line {
    let size = x.size()
    // One shared length is handed to the native call for all five buffers, so
    // they must agree; validate before allocating so a failure leaks nothing.
    if (y.size() != size || z.size() != size || sz.size() != size || color.size() != size) {
        throw IllegalArgumentException("scatter3: x, y, z, sz and color must have the same length")
    }
    let cstr_line_spec = unsafe { LibC.mallocCString(line_spec) }
    let handle = unsafe {
        c_scatter3_with_size_color(x.ptr, y.ptr, z.ptr, sz.ptr, color.ptr, size, cstr_line_spec)
    }
    // The temporary C string is only needed for the duration of the call.
    unsafe { LibC.free(cstr_line_spec) }
    return Line(handle)
}

// Builds sample data for the scatter3 tests: points on three concentric
// spheres of radius 0.5, 0.75 and 1.0, each sampled on a 17 x 17 grid of
// latitudes (in [-pi/2, pi/2]) and longitudes (in [-pi, pi]).
// Returns flat x, y, z vectors holding the 0.5 shell first, then 0.75, then 1.0.
func generate_data(): (Vector<Float64>, Vector<Float64>, Vector<Float64>) {
    let steps = 17
    let lastIdx = steps - 1
    let halfSpan = 16.0
    let piValue = Float64.getPI()
    let grid = linspace(-halfSpan, halfSpan, num: steps)

    // Longitude covers [-pi, pi]; latitude covers [-pi/2, pi/2].
    let lon = grid.apply({g => g / halfSpan * piValue})
    let lat = grid.apply({g => g / halfSpan * piValue / 2.0})

    let sinLat = lat.apply({a => sin(a)})
    let cosLat = lat.apply({a => cos(a)})
    let sinLon = lon.apply({a => sin(a)})
    let cosLon = lon.apply({a => cos(a)})
    // Force exact zeros at the poles and on the +/-pi seam, where sin/cos
    // only return values close to zero.
    cosLat[0] = 0.0
    cosLat[lastIdx] = 0.0
    sinLon[0] = 0.0
    sinLon[lastIdx] = 0.0

    let gridX = empty<Float64>(steps, steps)
    let gridY = empty<Float64>(steps, steps)
    let gridZ = empty<Float64>(steps, steps)
    for (row in 0..steps) {
        for (col in 0..steps) {
            gridX[row, col] = cosLat[row] * cosLon[col]
            gridY[row, col] = cosLat[row] * sinLon[col]
            gridZ[row, col] = sinLat[row]
        }
    }

    // Stack three scaled copies (radius 0.5, 0.75, 1.0) of a flattened grid.
    let shells = {flat: Vector<Float64> =>
        concat([flat.apply({c => c * 0.5}), flat.apply({c => c * 0.75}), flat])
    }
    return (shells(reshape(gridX)), shells(reshape(gridY)), shells(reshape(gridZ)))
}

// Plots the three concentric spheres from generate_data() with default
// markers and saves the figure as a PNG, then clears it for the next test.
func testScatter3Plot1(): Unit {
    // Progress line in the same format as the other scatter3 tests.
    print("    - testScatter3Plot1\n")
    let (x, y, z) = generate_data()
    scatter3(x, y, z)

    save("./tests/imgs/scatter3_plot/scatter3Plot_1.png", "png")
    clear()
}

// Plots the three spheres with one marker size per shell (16, 8, 2, inner to
// outer), sets the view angles to (40, 35), saves a PNG and clears the figure.
func testScatter3Plot2(): Unit {
    let (xs, ys, zs) = generate_data()
    let perShell = xs.size() / 3
    let markerSizes = concat([vector(perShell, 16.0), vector(perShell, 8.0), vector(perShell, 2.0)])
    scatter3(xs, ys, zs, markerSizes)
    view(40.0, 35.0)
    save("./tests/imgs/scatter3_plot/scatter3Plot_2.png", "png")
    clear()
}

// Plots the three spheres with per-shell marker sizes (16, 8, 2) and per-shell
// color values (1, 2, 3), sets the view angles, saves a PNG and clears the figure.
func testScatter3Plot3(): Unit {
    let (px, py, pz) = generate_data()
    let shellLen = px.size() / 3
    let sizeByShell = concat([vector(shellLen, 16.0), vector(shellLen, 8.0), vector(shellLen, 2.0)])
    let colorByShell = concat([vector(shellLen, 1.0), vector(shellLen, 2.0), vector(shellLen, 3.0)])
    scatter3(px, py, pz, sizeByShell, colorByShell)
    view(40.0, 35.0)
    save("./tests/imgs/scatter3_plot/scatter3Plot_3.png", "png")
    clear()
}

// Plots 250 points along a noisy helix (radius 2, two turns as z runs over
// [0, 4*pi]) with the "filled" marker spec and saves an SVG. The generator is
// seeded with 1, presumably so the image is reproducible.
public func testScatter3Plot4() {
    print("    - testScatter3Plot4\n")
    let rng = Random(1)
    let height = linspace(0.0, 4.0 * Float64.getPI(), num: 250)
    let ringX = height.apply({h => cos(h)}) * 2.0
    let ringY = height.apply({h => sin(h)}) * 2.0
    // Draw the x noise before the y noise to keep the random sequence identical.
    let noisyX = ringX + rand(rng, 250)
    let noisyY = ringY + rand(rng, 250)
    scatter3(noisyX, noisyY, height, line_spec: "filled")

    save("./tests/imgs/scatter3_plot/scatter3Plot_4.svg", "svg")
    clear()
}

// Same noisy helix as testScatter3Plot4 (same seed), drawn with the "*"
// marker spec and saved as an SVG.
public func testScatter3Plot5() {
    print("    - testScatter3Plot5\n")
    let rng = Random(1)
    let height = linspace(0.0, 4.0 * Float64.getPI(), num: 250)
    let ringX = height.apply({h => cos(h)}) * 2.0
    let ringY = height.apply({h => sin(h)}) * 2.0
    // Draw the x noise before the y noise to keep the random sequence identical.
    let noisyX = ringX + rand(rng, 250)
    let noisyY = ringY + rand(rng, 250)
    scatter3(noisyX, noisyY, height, line_spec: "*")

    save("./tests/imgs/scatter3_plot/scatter3Plot_5.svg", "svg")
    clear()
}

// Plots the three spheres with per-shell sizes and colors, then overrides the
// marker face color on the returned plot object (args presumably RGB), sets
// the view angles, saves a PNG and clears the figure.
func testScatter3Plot6(): Unit {
    let (ptsX, ptsY, ptsZ) = generate_data()
    let shellLen = ptsX.size() / 3
    let markerSizes = concat([vector(shellLen, 16.0), vector(shellLen, 8.0), vector(shellLen, 2.0)])
    let markerColors = concat([vector(shellLen, 1.0), vector(shellLen, 2.0), vector(shellLen, 3.0)])
    let plot = scatter3(ptsX, ptsY, ptsZ, markerSizes, markerColors)
    plot.marker_face_color(0.0, 0.5, 0.5)
    view(40.0, 35.0)
    save("./tests/imgs/scatter3_plot/scatter3Plot_6.png", "png")
    clear()
}

// Entry point of the scatter3 plot suite: prints a header line and runs every
// scatter3 case in order. Each case draws, saves its image and calls clear(),
// so the cases are run strictly one after another.
public func testScatter3Plot() {
    print("  + testScatter3Plot\n")
    testScatter3Plot1()
    testScatter3Plot2()
    testScatter3Plot3()
    testScatter3Plot4()
    testScatter3Plot5()
    testScatter3Plot6()
}