package scientific.stats.frequencystat

import std.collection.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.stats.summary.variance
import scientific.stats.summary.std
import scientific.numbers.*
import scientific.linear.*

/**
 * Result bundle returned by `binned_statistic`.
 *
 * The fields are declared `public` so that code outside this package, which
 * can already obtain an instance through the public `binned_statistic`
 * functions, is also able to read the computed values.
 */
public class BinnedStatisticRes {
    // Value of the requested statistic for each bin (one entry per bin).
    public let statistic: Vector<Float64>
    // Bin boundaries; holds one more entry than `statistic`.
    public let bin_edges: Vector<Float64>
    // Bin number assigned to each input data point (bin numbering starts at 1).
    public let binnumber: Vector<Int64>

    /**
     * Stores the three result vectors as given; no copies are made here.
     */
    public init(statistic: Vector<Float64>, bin_edges: Vector<Float64>, binnumber: Vector<Int64>) {
        this.statistic = statistic
        this.bin_edges = bin_edges
        this.binnumber = binnumber
    }
}

/**
 * Computes a single summary statistic over the values collected for one bin.
 *
 * Supported `statistic_type` names: "mean", "std", "median", "count", "sum",
 * "min", "max".
 *
 * Empty input: "count" and "sum" yield 0.0; every other supported statistic
 * yields NaN, because it is undefined for an empty sample (and the median
 * branch would otherwise index position -1).
 *
 * Throws IllegalArgumentException when `statistic_type` is not one of the
 * supported names, instead of silently reporting 0.0.
 */
func sup_statistic(list: ArrayList<Float64>, statistic_type: String): Float64 {
    let n = list.size

    // Value reported for an empty sample. This match doubles as the
    // validation of the statistic name, so it runs for every call.
    let emptyValue = match (statistic_type) {
        case "count" | "sum" => 0.0
        case "mean" | "std" | "median" | "min" | "max" => Float64.NaN
        case _ => throw IllegalArgumentException(
            "sup_statistic: unsupported statistic_type '${statistic_type}'.")
    }
    if (n == 0) {
        return emptyValue
    }

    let vec = vector<Float64>(list.toArray())
    return match (statistic_type) {
        case "mean" => sum(vec) / Float64(n)
        case "std" => std(vec)
        case "median" =>
            let ordered = sorted(vec)
            if (n % 2 == 1) {
                // Odd count: the middle element is the median.
                ordered[n / 2]
            } else {
                // Even count: average the two central elements.
                0.5 * (ordered[n / 2 - 1] + ordered[n / 2])
            }
        case "count" => Float64(n)
        case "sum" => sum(vec)
        case "min" => min(vec)
        case "max" => max(vec)
        // Not reachable: unknown names were rejected by the match above.
        case _ => emptyValue
    }
}

/**
 * Splits `range` into `numbins` equal-width bins, assigns every element of
 * `data` to a bin, and computes `statistic_type` over the entries of `values`
 * whose matching `data` element fell into each bin.
 *
 * Parameters:
 *  - data: sample positions that decide the bin of each observation.
 *  - values: observations to summarise; `values[i]` belongs to `data[i]`.
 *    Must hold at least `data.size()` elements (extra elements are ignored).
 *  - range: (lower, upper) bounds of the binning interval.
 *  - statistic_type: name of the statistic, forwarded to `sup_statistic`.
 *  - numbins: number of bins; must be positive.
 *
 * Bins are half-open `[edge_j, edge_j+1)`, except the last one, which also
 * includes the upper bound of `range`.
 *
 * Returns a BinnedStatisticRes holding:
 *  - statistic: one value per bin. For an empty bin this is 0.0 when
 *    `statistic_type` is "count" or "sum" and NaN otherwise.
 *  - bin_edges: the `numbins + 1` bin boundaries.
 *  - binnumber: 1-based bin number of each data point; 0 marks a point that
 *    matched no bin (outside `range`). Such points are left out of the
 *    statistic.
 *
 * Throws IllegalArgumentException when `numbins` is not positive or `values`
 * is shorter than `data`.
 */
public func binned_statistic(data: Vector<Float64>, values: Vector<Float64>, range: (Float64, Float64),
                             statistic_type!: String = "mean", numbins!: Int64 = 10): BinnedStatisticRes {
    let n = data.size()
    if (numbins <= 0) {
        throw IllegalArgumentException("binned_statistic: numbins must be positive.")
    }
    if (values.size() < n) {
        throw IllegalArgumentException("binned_statistic: values must have at least as many elements as data.")
    }

    let range_l = range[0]
    let range_r = range[1]
    let binsize = (range_r - range_l) / Float64(numbins)

    // Equally spaced bin boundaries from range_l up to range_r.
    let bin_edges = vector<Float64>(numbins + 1, 0.0)
    for (i in 0..(numbins + 1)) {
        bin_edges[i] = range_l + Float64(i) * binsize
    }

    // 1-based bin number per data point; 0 means "no bin matched".
    let binnumber = vector<Int64>(n, 0)
    for (i in 0..n) {
        let x = data[i]
        for (j in 0..numbins) {
            let insideBin = x >= bin_edges[j] && x < bin_edges[j + 1]
            // The last bin is closed on the right so the upper bound is kept.
            let onUpperBound = j == (numbins - 1) && x == bin_edges[j + 1]
            if (insideBin || onUpperBound) {
                binnumber[i] = j + 1
                break
            }
        }
    }

    // Collect the values of each bin. Points with bin number 0 are skipped;
    // indexing with `0 - 1` would otherwise be out of bounds.
    let groups = Array<ArrayList<Float64>>(numbins, {_ => ArrayList<Float64>()})
    for (i in 0..n) {
        let bin = binnumber[i]
        if (bin > 0) {
            groups[bin - 1].add(values[i])
        }
    }

    let statistic = vector<Float64>(numbins, 0.0)
    for (i in 0..numbins) {
        if (groups[i].size > 0) {
            statistic[i] = sup_statistic(groups[i], statistic_type)
        } else if (statistic_type == "count" || statistic_type == "sum") {
            // Nothing counted / nothing summed.
            statistic[i] = 0.0
        } else {
            // The statistic is undefined for an empty bin.
            statistic[i] = Float64.NaN
        }
    }

    return BinnedStatisticRes(statistic, bin_edges, binnumber)
}

/**
 * Convenience overload of `binned_statistic` that derives the binning range
 * from the data itself: the range passed on is `(min(data), max(data))`.
 *
 * Parameters:
 *  - data: sample positions that decide the bin of each observation.
 *  - values: observations to summarise, paired with `data` by index.
 *  - statistic_type: name of the statistic to compute per bin.
 *  - numbins: number of equal-width bins.
 *
 * Throws IllegalArgumentException when `data` has no elements, since no
 * range can be derived from it.
 */
public func binned_statistic(data: Vector<Float64>, values: Vector<Float64>, 
                             statistic_type!: String = "mean", numbins!: Int64 = 10): BinnedStatisticRes {
    if (data.size() == 0) {
        throw IllegalArgumentException("binned_statistic: input data is empty.")
    }

    // The smallest and largest sample bound the binning interval.
    let range = (min(data), max(data))

    return binned_statistic(data, values, range, statistic_type: statistic_type, numbins: numbins)
}

@Test
public class TestBinnedStatistic {
    // Exercises the auto-range overload with two bins on five samples.
    // The derived range is [1, 34], so the edges are 1.0, 17.5 and 34.0:
    // samples 2, 7, 1 land in the first bin and 20, 34 in the second.
    @TestCase
    func testBinned_statistic(): Unit {
        let samples = vector<Float64>([20.0, 2.0, 7.0, 1.0, 34.0])
        let observations = vector<Float64>([0.0, 1.0, 2.0, 3.0, 4.0])

        // Median per bin.
        let medianRes = binned_statistic(samples, observations, statistic_type: "median", numbins: 2)
        @Assert(approxEqual(medianRes.statistic, vector([2.0, 2.0])))
        @Assert(approxEqual(medianRes.bin_edges, vector([1.0, 17.5, 34.0])))

        // Mean per bin.
        let meanRes = binned_statistic(samples, observations, statistic_type: "mean", numbins: 2)
        @Assert(approxEqual(meanRes.statistic, vector([2.0, 2.0])))
        @Assert(approxEqual(meanRes.bin_edges, vector([1.0, 17.5, 34.0])))

        // Minimum per bin.
        let minRes = binned_statistic(samples, observations, statistic_type: "min", numbins: 2)
        @Assert(approxEqual(minRes.statistic, vector([1.0, 0.0])))
        @Assert(approxEqual(minRes.bin_edges, vector([1.0, 17.5, 34.0])))
    }
}