package scientific.stats.random

import std.time.*
import std.unittest.*
import std.unittest.testmacro.*

import scientific.utils.assertEqual
import scientific.linear.empty
import scientific.linear.vector
import scientific.linear.approxEqual

// Number of 32-bit key words in the MT19937 state (624, the Mersenne Twister
// state size n). Used in this file to size the malloc'd state buffers.
let RK_STATE_LEN = 624
// Binding to C `malloc`; returns an untyped pointer (`CPointer<Unit>`) that callers
// in this file cast to a typed CPointer. Nothing in the code shown frees it.
foreign func malloc(size: UIntNative): CPointer<Unit>

/**
 * Cangjie-side declaration of the C `mt19937_state` type.
 *
 * The `mt19937_*` foreign functions take a `CPointer<mt19937_state>`; in the code
 * shown, this struct is only used as that pointee type and is never constructed.
 * NOTE(review): a `key` *pointer* field does not match a C layout of
 * `uint32_t key[RK_STATE_LEN]; int pos;` (the layout implied by the
 * 4 * (RK_STATE_LEN + 1)-byte buffers allocated for `Random`). Passing the address
 * of a constructed `mt19937_state` to C would be wrong if that is the real C
 * layout. Confirm against the C sources before using this initializer.
 */
@C
public struct mt19937_state {
    var key: CPointer<UInt32>
    var pos: Int32

    /**
     * Sets `pos` to 0 and points `key` at a freshly malloc'd, uninitialized buffer
     * of RK_STATE_LEN UInt32 words. The buffer is not freed here.
     *
     * Throws IllegalStateException if malloc returns null, rather than storing a
     * null pointer that would only fail later inside C code.
     */
    init() {
        let raw = unsafe { malloc(UIntNative(4 * RK_STATE_LEN)) }
        if (raw.isNull()) {
            throw IllegalStateException("mt19937_state: failed to allocate key buffer")
        }
        this.pos = 0
        this.key = unsafe { CPointer<UInt32>(raw) }
    }
}


// Foreign bindings to the C MT19937 implementation (defined outside this file).
// These signatures must match the C declarations exactly.
// NOTE(review): presumably numpy-derived randomkit code, since the tests compare
// against numpy output for seed 2. Confirm against the C sources.
// Presumably initializes the state pointed to by `state` from seed `s`.
foreign func mt19937_seed(state: CPointer<mt19937_state>, s: UInt32): Unit
// Returns the next 64-bit draw (presumably advancing the state).
foreign func mt19937_next64(state: CPointer<mt19937_state>): UInt64
// Returns the next 32-bit draw (presumably advancing the state).
foreign func mt19937_next32(state: CPointer<mt19937_state>): UInt32
// Returns the next Float64 draw. The tests expect it to reproduce numpy's
// `np.random.rand` for the same seed (values in [0, 1)). TODO confirm for all inputs.
foreign func mt19937_next_double(state: CPointer<mt19937_state>): Float64


/**
 * Allocates an uninitialized buffer for the C-side MT19937 state used by `Random`.
 *
 * The size is 4 * (RK_STATE_LEN + 1) bytes: RK_STATE_LEN 32-bit key words plus one
 * 32-bit slot for the position counter.
 * NOTE(review): this assumes the C struct is `uint32_t key[RK_STATE_LEN]; int pos;`.
 * Confirm against the C sources.
 * The buffer is never freed: `Random` is a value type with no destructor.
 *
 * Throws IllegalStateException if malloc returns null, instead of handing a null
 * pointer to the C seeding routine.
 */
func allocMt19937State(): CPointer<mt19937_state> {
    let raw = unsafe { malloc(UIntNative(4 * (RK_STATE_LEN + 1))) }
    if (raw.isNull()) {
        throw IllegalStateException("Random: failed to allocate MT19937 state")
    }
    return unsafe { CPointer<mt19937_state>(raw) }
}

/**
 * Mersenne Twister (MT19937) pseudo-random generator backed by the C `mt19937_*`
 * routines. Not suitable for cryptographic use.
 *
 * Copies of a `Random` value share the same underlying C state (only the pointer
 * is copied), so drawing from one copy advances the others.
 */
@C
public struct Random {
    var pc: CPointer<mt19937_state>

    /** Creates a generator seeded from the current wall-clock time (see `setSeed()`). */
    public init() {
        this.pc = allocMt19937State()
        this.setSeed()
    }

    /** Creates a generator with the fixed seed `s`; equal seeds give equal sequences. */
    public init(s: UInt32) {
        this.pc = allocMt19937State()
        this.setSeed(s)
    }

    /** Re-seeds the underlying C state with `s`. */
    public func setSeed(s: UInt32): Unit {
        unsafe { mt19937_seed(pc, s) }
    }

    /**
     * Re-seeds from the nanosecond field of `DateTime.now()`.
     *
     * The previous `nanosecond % 10000` allowed only 10_000 distinct seeds, or just
     * 10 when the clock's nanosecond field is a multiple of 1000. Reducing modulo
     * 1e9 keeps the full field while guaranteeing the value fits in a UInt32
     * (1e9 < 2^32).
     */
    public func setSeed(): Unit {
        let now = DateTime.now()
        setSeed(UInt32(now.nanosecond % 1000000000))
    }

    /** Returns the next 64-bit value from `mt19937_next64`. */
    public func nextUInt64(): UInt64 {
        return unsafe { mt19937_next64(pc) }
    }

    /** Returns the next 32-bit value from `mt19937_next32`. */
    public func nextUInt32(): UInt32 {
        return unsafe { mt19937_next32(pc) }
    }

    /**
     * Returns the next Float64 from `mt19937_next_double`. For seed 2 the tests
     * expect numpy's `np.random.rand` sequence.
     */
    public func nextFloat64(): Float64 {
        return unsafe { mt19937_next_double(pc) }
    }

    /**
     * Returns a Float32 uniformly distributed in [0, 1).
     *
     * Keeps the top 24 bits of a 32-bit draw (Float32's significand width) and scales
     * by 2^-24, so every result is exact and strictly below 1.0. The previous
     * `Float32(u32) / 4294967295.0` rounded draws near 2^32 up to 2^32 and could
     * return exactly 1.0.
     */
    public func nextFloat32(): Float32 {
        return Float32(nextUInt32() >> 8) / 16777216.0
    }
}


/**
 * Regression tests that pin `Random` to fixed reference sequences for seed 2.
 */
@Test
public class TestRandom {
    /*
     * Float64 draws must agree (atol 1e-7) with NumPy's legacy generator:
     *
     *     import numpy as np
     *     np.random.seed(2)
     *     print(np.random.rand(10))
     */
    @TestCase
    func testRandom1(): Unit {
        let rng = Random(2)
        let drawn = empty<Float64>(10)
        for (idx in 0..10) {
            drawn[idx] = rng.nextFloat64()
        }
        let expected = vector<Float64>([
            0.4359949, 0.02592623, 0.54966248, 0.43532239, 0.4203678,
            0.33033482, 0.20464863, 0.61927097, 0.29965467, 0.26682728
        ])
        @Assert(approxEqual(drawn, expected, atol: 1e-7))
    }

    // The first ten `nextUInt64()` draws for seed 2 must match these values exactly.
    @TestCase
    func testRandom2(): Unit {
        let rng = Random(2)
        let drawn = empty<UInt64>(10)
        for (idx in 0..10) {
            drawn[idx] = rng.nextUInt64()
        }
        let expected = vector<UInt64>([
            8042686386972756495, 478254495130285640, 10139483024654235947,
            8030280778952246347, 7754417253617937671, 6093601889722073521,
            3775101017224948299, 11423533102773113903, 5527653144310611231,
            4922094483131799956
        ])
        @Assert(drawn == expected)
        // NOTE: uncommenting `print(drawn)` here was observed to crash (core dump).
    }

    // The first ten `nextUInt32()` draws for seed 2 must match these values exactly.
    @TestCase
    func testRandom3(): Unit {
        let rng = Random(2)
        let drawn = empty<UInt32>(10)
        for (idx in 0..10) {
            drawn[idx] = rng.nextUInt32()
        }
        let expected = vector<UInt32>([
            1872583848, 794921487, 111352301, 4000937544, 2360782358,
            4070471979, 1869695442, 2081981515, 1805465960, 1376693511
        ])
        @Assert(drawn == expected)
        // NOTE: uncommenting `print(drawn)` here was observed to crash (core dump).
    }
}