package scientific.stats.mcmc

import std.math.*

import scientific.linear.*
import scientific.matplot.*
import scientific.stats.normal.*
import scientific.stats.continuous.*
import scientific.stats.summary.*
import scientific.stats.random.Random

// foreign func malloc(size: UIntNative): CPointer<Unit>
// foreign func genmn(param: CPointer<Float64>): CPointer<Float64>

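/*
 * Draw one sample from a multivariate normal distribution (via the foreign genmn routine).
 * param:
 *  PARM(1) contains the dimension of the deviate, P.
 *  PARM(2:P+1) contains the mean vector.
 *  PARM(P+2:P*(P+3)/2+1) contains the upper triangle of the Cholesky
 *  factor of the covariance matrix.
 * std: upper triangle of the Cholesky factor of the covariance matrix.
 */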
// public func multiNormSample(n: Int64, mean: Tensor<Float64>, std: Tensor<Float64>): Tensor<Float64> {

//     print("${n}\n")
//     print("${mean[0]}\n")
//     print("${std[0]}\n")
//     let num = n * ( n + 3 ) / 2 + 1
//     var t1 = unsafe { malloc(UIntNative(8 * num)) }
//     var param = unsafe { CPointer<Float64>(t1) }

//     unsafe { param.write(Float64(n)) }
//     for(i in 0..n) {
//         unsafe { param.write(mean[i]) } 
//     }
//     for(i in 0..(num - n - 1)) {
//         unsafe { param.write(std[i]) } 
//     }

   

//     var t2 = unsafe { malloc(UIntNative(8 * n)) }
//     var samples = unsafe { CPointer<Float64>(t1) }
//     unsafe {samples = genmn(param)}

//     let res = Tensor<Float64>(n)
//     for(i in 0..n) {
//         res[i] = unsafe { samples.read(i) }
//     }

//     return res
// }

/*
 * Natural logarithm of x; a thin wrapper over std.math.log.
 */
public func mylog(x: Float64): Float64 {
    let result = log(x)
    return result
}

/*
 * Natural logarithm paired with its derivative closure.
 * Returns (log(x), pullback), where pullback(dy) = dy / x: the incoming
 * gradient dy multiplied by d/dx log(x) = 1 / x.
 */
public func logAdj(x: Float64): (Float64, (Float64) -> Float64) {
    let value = log(x)
    let pullback = { upstream: Float64 => upstream / x }
    return (value, pullback)
}

/*
 * Exponential function: e raised to the power x, computed as pow(e, x).
 */
public func mypow(x: Float64): Float64 {
    let base = Float64.getE()
    return pow(base, x)
}

/*
 * Exponential e^x paired with its derivative closure.
 * Returns (e^x, pullback), where pullback(dy) = dy * e^x, because the
 * derivative of e^x is e^x itself. The forward value is computed once and
 * captured by the closure.
 */
public func powAdj(x: Float64): (Float64, (Float64) -> Float64) {
    let base = Float64.getE()
    let value = pow(base, x)
    let pullback = { upstream: Float64 => upstream * value }
    return (value, pullback)
}

/*
public func readDataIris(path: String): Matrix<Float64> {

    var fs: FileStream = FileStream(path)

    let data = empty<Float64>(100, 5)
    if (fs.openFile()) {

        var toRead_1: Array<UInt8> = Array<UInt8>(27, { i => 0 })
        for(i in 0..50) {

            fs.seek(28 * i, SeekOrigin.BeginPos)
            for(j in 0..toRead_1.size()) {
                toRead_1[j] = 0
            }
            let num = fs.read(toRead_1)
            for(j in 0..4) {
                var temp = 0.0
                temp += Float64(toRead_1[j*4] - 48)
                temp += Float64(toRead_1[j*4 + 2] - 48) * 0.1
                // print("${temp}\n")
                data[i,j] = temp
            }
            data[i, 4] = 0.0

        }


        var toRead_2: Array<UInt8> = Array<UInt8>(31, { i => 0 })
        for(i in 0..50) {
            fs.seek(1400 + i * 32, SeekOrigin.BeginPos)
            // fs.seek(28 * i, SeekOrigin.BeginPos)
            for(j in 0..toRead_2.size()) {
                toRead_2[j] = 0
            }
            let num = fs.read(toRead_2)
            for(j in 0..4) {
                var temp = 0.0
                temp += Float64(toRead_2[j*4] - 48)
                temp += Float64(toRead_2[j*4 + 2] - 48) * 0.1
                // print("${temp}\n")
                data[i + 50, j] = temp
            }
            data[i + 50, 4] = 1.0
        }

        // var toRead_3: Array<UInt8> = Array<UInt8>(30, { i => 0 })
        // for(i in 0..50) {
        //     fs.seek(3000+ i * 31, SeekOrigin.BeginPos)
        //     for(j in 0..toRead_3.size()) {
        //         toRead_3[j] = 0
        //     }
        //     let num = fs.read(toRead_3)
        //     for(j in 0..4) {
        //         var temp = 0.0
        //         temp += Float64(toRead_3[j*4] - 48)
        //         temp += Float64(toRead_3[j*4 + 2] - 48) * 0.1
        //         print("${temp}\n")
        //         data[(i + 100, j)] = temp
        //     }
        //     data[(i + 100, 4)] = 2.0
        // }

        fs.flush()
    
    }

    fs.close()
    return data
}
*/

/*
 * Saves one diagnostic figure per parameter chain, containing a histogram
 * and a line plot of the chain.
 *
 * draws: sample matrix. It is transposed so that each row of the transposed
 *        matrix is treated as one chain (presumably draws is
 *        iterations x parameters; confirm with callers).
 * path:  sub-directory under ./tests/imgs/mcmc/. figure_<i>.svg is written
 *        for row i.
 */
public func plotTrace(draws: Matrix<Float64>, path: String) {
    let chains: Matrix<Float64> = draws.transpose()
    let rowCount = chains.getRows()
    for (i in 0..rowCount) {
        let chain = chains[i]

        // subplot(1, 2, 0): histogram of the chain's values.
        subplot(1, 2, 0)
        hist(chain)

        // subplot(1, 2, 1): the chain in draw order.
        subplot(1, 2, 1)
        plot(chain)

        save("./tests/imgs/mcmc/${path}/figure_${i}.svg", "svg")
        // Cleared before the next chain's figure (presumably resets plot state).
        clear()
    }
}


/*
 * Computes the lag-k autocorrelation coefficient of a series:
 *
 *   rho(k) = sum_{i=0}^{n-k-1} (x_i - mu) * (x_{i+k} - mu) / (n * variance(x))
 *
 * data: the series (e.g. one MCMC chain).
 * k:    lag. When k >= n there are no pairs, the sum is empty and the
 *       numerator is 0.0. NOTE(review): k < 0 would index out of bounds, and
 *       a constant series (variance 0) yields NaN/Inf.
 * The normalisation divides by n and by whatever estimator `variance`
 * implements (population vs. sample variance; confirm).
 */
public func acf(data: Vector<Float64>, k: Int64): Float64 {
    // Named `mu` so the local does not shadow the imported `mean` function.
    let mu = mean(data)
    let va = variance(data)
    let n = data.size()

    var res = 0.0
    // Pairs (i, i + k) exist for i = 0 .. n-k-1. The range end is exclusive,
    // so the bound is n - k. The previous bound, n - k - 1, silently dropped
    // the last pair.
    for (i in 0..(n - k)) {
        res += (data[i] - mu) * (data[i + k] - mu)
    }

    return res / va / Float64(n)
}


/*
 * Plots the autocorrelation function (lags 0 .. k-1) of every chain and
 * reports an effective sample size.
 *
 * draws: sample matrix. It is transposed so that each row of the result is
 *        one chain.
 * k:     number of lags evaluated. Per the original note, k is the number of
 *        samples at which sampling is considered to have converged.
 * path:  sub-directory under ./tests/imgs/mcmc/ that receives acf_<i>.svg
 *        for row i.
 *
 * For each chain, the first lag whose ACF falls below 0.02 triggers a
 * single print of ess(chain, lag).
 */
public func plotAcf(draws: Matrix<Float64>, k: Int64, path: String) {
    let trace: Matrix<Float64> = draws.transpose()
    let n = trace.getRows()
    for (i in 0..n) {
        let chain = trace[i]
        let acfTrace = empty<Float64>(k)
        var reported = false

        for (lag in 0..k) {
            let rho = acf(chain, lag)
            if (!reported && rho < 0.02) {
                print("${path}...ess...${ess(chain, lag)}\n")
                reported = true
            }
            acfTrace[lag] = rho
        }

        plot(acfTrace)
        save("./tests/imgs/mcmc/${path}/acf_${i}.svg", "svg")
        clear()
    }
}


/*
 * Effective sample size estimate:
 *
 *   ESS = n / (1 + 2 * sum_{lag=1}^{k-1} acf(data, lag))
 *
 * data: the chain.
 * k:    one past the last lag included. Lags 1 through k-1 are summed, and
 *       none are summed when k <= 1 (ESS = n).
 */
public func ess(data: Vector<Float64>, k: Int64): Float64 {
    var rhoSum = 0.0
    for (lag in 1..k) {
        rhoSum += acf(data, lag)
    }
    let n = data.size()
    return Float64(n) / (1.0 + 2.0 * rhoSum)
}

/*
 * Computes the squared maximum mean discrepancy between samples a and b
 * under the kernel coreFun, using the biased (V-statistic) estimator that
 * averages over all pairs, including i == j:
 *
 *   MMD^2 = mean_{i,j} k(a_i, a_j) - 2 * mean_{i,j} k(a_i, b_j)
 *           + mean_{i,j} k(b_i, b_j)
 *
 * The kernel is evaluated in the order a-a, then a-b, then b-b.
 */
func cmmd(a: Vector<Float64>, b: Vector<Float64>, coreFun:(Float64,Float64) -> Float64): Float64 {
    // Sum of coreFun(x[i], y[j]) over every (i, j), with i as the outer index.
    func kernelSum(x: Vector<Float64>, y: Vector<Float64>): Float64 {
        var acc = 0.0
        for (i in 0..x.size()) {
            for (j in 0..y.size()) {
                acc += coreFun(x[i], y[j])
            }
        }
        return acc
    }

    let sumAA = kernelSum(a, a)
    let sumAB = kernelSum(a, b)
    let sumBB = kernelSum(b, b)

    let fa = Float64(a.size())
    let fb = Float64(b.size())
    return sumAA / (fa * fa) - 2.0 * sumAB / (fa * fb) + sumBB / (fb * fb)
}

/*
 * Degree-2 polynomial kernel: k(x, y) = (1 + x * y)^2.
 * Previously considered alternative: exp(-abs(x - y) / 2.0)  (fix_sigma = 2.0)
 */
func coreFun_1(x: Float64, y: Float64): Float64 {
    let shifted = 1.0 + x * y
    return pow(shifted, 2.0)
}

/*
 * Squared maximum mean discrepancy between a and b, computed by cmmd with
 * the polynomial kernel coreFun_1.
 */
public func mmd(a: Vector<Float64>, b: Vector<Float64>): Float64 {
    let kernel: (Float64, Float64) -> Float64 = coreFun_1
    return cmmd(a, b, kernel)
}

/*
 * Compares data simulated from the true linear model with data simulated
 * from the posterior-mean model. It plots the MMD between the two samples
 * as the number of points grows.
 *
 * draws:  posterior samples. After transposition, rows 0, 1 and 2 are taken
 *         to be beta_0, beta_1 and tau respectively.
 * beta_0, beta_1, tau: true parameters used to generate the reference data.
 *         NOTE(review): tau is passed straight to normalSampleFloat64 as its
 *         spread argument; confirm whether that is a std-dev or a precision.
 * path:   sub-directory under ./tests/imgs/mcmc/ that receives mmd.svg.
 * r:      random source for all sampling.
 *
 * For each of 1000 points, x is drawn with uniformSample(r, 0.0, 4.0). y is
 * then drawn once from the true model and once from the estimated model.
 * The MMD is evaluated on the first 100, 200, ..., 1000 pairs.
 */
public func plotMmd_linear(draws: Matrix<Float64>, beta_0: Float64, beta_1: Float64, tau: Float64, path: String, r: Random) {

    let trace: Matrix<Float64> = draws.transpose()
    // Use the posterior mean of each parameter as the point estimate.
    let beta_0_s = mean(trace[0])
    let beta_1_s = mean(trace[1])
    let tau_s    = mean(trace[2])

    let num = 1000          // total simulated points
    let step = 100          // MMD is evaluated every `step` points
    let checkpoints = num / step

    let res = empty<Float64>(checkpoints)

    // Column 0: x; column 1: y drawn from the true model.
    var data_comp = empty<Float64>(num, 2)
    // y drawn from the posterior-mean model at the same x.
    let data_sample = empty<Float64>(num)

    for (i in 0..num) {
        data_comp[i, 0] = uniformSample(r, 0.0, 4.0)
        data_comp[i, 1] = normalSampleFloat64(r, beta_0 + beta_1 * data_comp[i, 0], tau)
        data_sample[i] = normalSampleFloat64(r, beta_0_s + beta_1_s * data_comp[i, 0], tau_s)
    }

    for (i in 0..checkpoints) {
        let len = (i + 1) * step
        let temp1 = empty<Float64>(len)
        let temp2 = empty<Float64>(len)
        for (j in 0..len) {
            temp1[j] = data_comp[j, 1]
            temp2[j] = data_sample[j]
        }
        res[i] = mmd(temp1, temp2)
    }

    plot(res)
    save("./tests/imgs/mcmc/${path}/mmd.svg", "svg")
    clear()

}