package scientific.numbers

import std.unittest.*
import std.unittest.testmacro.*

// C-side accessors (implemented outside this file) used by Complex64/Complex32
// `read`/`write` to move complex values through an opaque `CPointer<Unit>`.
// The real and imaginary parts are read by separate calls at the same `offset`.
// NOTE(review): whether `offset` is an element index or a byte offset is not
// visible here — confirm against the C implementation.
// Float64 pair (Complex64): write both parts / read the real part / read the imaginary part.
foreign func c_writeComplex64(ptr: CPointer<Unit>, real: Float64, imag: Float64, offset: UIntNative): Unit
foreign func c_getComplex64Real(ptr: CPointer<Unit>, offset: UIntNative): Float64
foreign func c_getComplex64Imag(ptr: CPointer<Unit>, offset: UIntNative): Float64
// Float32 pair (Complex32): write both parts / read the real part / read the imaginary part.
foreign func c_writeComplex32(ptr: CPointer<Unit>, real: Float32, imag: Float32, offset: UIntNative): Unit
foreign func c_getComplex32Real(ptr: CPointer<Unit>, offset: UIntNative): Float32
foreign func c_getComplex32Imag(ptr: CPointer<Unit>, offset: UIntNative): Float32

/**
 * Complex-number capability layered on top of `CNumber<T>`.
 * `T` is the implementing type itself (e.g. `Complex64 <: Complex<Complex64>`).
 */
public interface Complex<T> <: CNumber<T> where T <: Equatable<T> {
    // Returns the complex conjugate; the implementations in this file negate the imaginary part.
    func conjugate(): T
    // Approximate equality with absolute tolerance `atol`. The implementations in this
    // file check real and imaginary parts separately, each |difference| strictly < atol.
    func approxEqual(right: T, atol: Float64): Bool
}

/**
 * Approximate equality for any `Complex<T>` type.
 *
 * Delegates to `a.approxEqual(b, atol)`; `atol` is the absolute tolerance
 * (named argument, default 1e-7).
 */
public func approxEqualC<T>(a: T, b: T, atol!: Float64 = 1e-7): Bool where T <: Complex<T> {
    // Last expression of the body is the return value.
    a.approxEqual(b, atol)
}

/**
 * Test helper: succeeds silently when `a` and `b` are approximately equal
 * (see `approxEqualC`). Otherwise it prints both values and throws
 * `AssertException`.
 */
public func assertApproxEqualC<T>(a: T, b: T, atol!: Float64 = 1e-7): Unit where T <: Complex<T> & ToString {
    // Only the mismatch path does anything; a match simply falls through.
    if (!approxEqualC<T>(a, b, atol: atol)) {
        print("a = ${a}, b = ${b}\n")
        throw AssertException()
    }
}

/**
 * Complex number with two Float64 components (`real`, `imag`).
 */
public struct Complex64 {
    public var real: Float64
    public var imag: Float64

    public init(real: Float64, imag: Float64) {
        this.real = real
        this.imag = imag
    }

    /**
     * Squared modulus `real^2 + imag^2`.
     * Computed directly, so it overflows to Inf for very large components;
     * `norm()` does not go through this value.
     */
    public func normsq(): Float64 {
        return this.real * this.real + this.imag * this.imag
    }

    /**
     * Modulus |z| = sqrt(real^2 + imag^2).
     *
     * Scales by the larger component before squaring. The result therefore
     * stays finite and accurate wherever |z| itself is representable. The
     * naive `sqrt(normsq())` overflows once a component exceeds about 1.3e154
     * and underflows for components below about 1e-154. NaN components still
     * produce NaN, and an infinite component (with no NaN) produces +Inf.
     */
    public func norm(): Float64 {
        let a = abs(this.real)
        let b = abs(this.imag)
        // If either is NaN, `a > b` is false, but the NaN still reaches the arithmetic below.
        let big = if (a > b) { a } else { b }
        let small = if (a > b) { b } else { a }
        // Equal magnitudes (including 0/0 and Inf/Inf, whose quotient would be NaN) have ratio 1.
        let ratio: Float64 = if (big == small) { 1.0 } else { small / big }
        return big * sqrt(1.0 + ratio * ratio)
    }

    /**
     * Reads one Complex64 from `ptr` at `offset` via the C accessors.
     * NOTE(review): `offset` units (element index vs. bytes) are defined by the C side.
     */
    public static func read(ptr: CPointer<Unit>, offset: Int64): Complex64 {
        let real = unsafe { c_getComplex64Real(ptr, UIntNative(offset)) }
        let imag = unsafe { c_getComplex64Imag(ptr, UIntNative(offset)) }
        return Complex64(real, imag)
    }

    /** Writes `val` to `ptr` at `offset` via the C accessor. */
    public static func write(ptr: CPointer<Unit>, val: Complex64, offset: Int64): Unit {
        unsafe { c_writeComplex64(ptr, val.real, val.imag, UIntNative(offset)) }
    }
}

extend Complex64 <: Equatable<Complex64> {
    /** Exact component-wise equality (IEEE semantics: a NaN part never compares equal). */
    public operator func ==(x: Complex64): Bool {
        if (this.real != x.real) {
            return false
        }
        return this.imag == x.imag
    }

    /** Logical negation of `==`. */
    public operator func !=(x: Complex64): Bool {
        return !(this == x)
    }
}

extend Complex64 <: Complex<Complex64> {
    /** Embeds an integer as `n + 0i`. */
    public static func fromInt(n: Int64): Complex64 {
        return Complex64(Float64(n), 0.0)
    }

    /** Size of one element: two Float64 components = 16 bytes. */
    public static func getSize(): Int64 {
        return 16
    }

    /** Type name string for this element type. */
    public static func getType(): String {
        return "Complex64"
    }

    public operator func +(x: Complex64): Complex64 {
        return Complex64(this.real + x.real, this.imag + x.imag)
    }

    public operator func -(x: Complex64): Complex64 {
        return Complex64(this.real - x.real, this.imag - x.imag)
    }

    /** (a + bi)(c + di) = (ac - bd) + (ad + bc)i */
    public operator func *(x: Complex64): Complex64 {
        return Complex64(this.real * x.real - this.imag * x.imag,
                         this.real * x.imag + this.imag * x.real)
    }

    /**
     * Complex division using Smith's algorithm.
     *
     * The textbook form divides by `c*c + d*d`, which overflows to Inf or
     * underflows to 0 for large or small divisors even when the quotient is
     * representable. Smith's method first divides by the larger divisor
     * component. Dividing by 0 + 0i yields NaN components, as before.
     */
    public operator func /(x: Complex64): Complex64 {
        let a = this.real
        let b = this.imag
        let c = x.real
        let d = x.imag
        if (abs(c) >= abs(d)) {
            // |d/c| <= 1, so no intermediate can blow up.
            let r = d / c
            let den = c + d * r
            return Complex64((a + b * r) / den, (b - a * r) / den)
        } else {
            // |c/d| < 1
            let r = c / d
            let den = c * r + d
            return Complex64((a * r + b) / den, (b * r - a) / den)
        }
    }

    public operator func -(): Complex64 {
        return Complex64(-this.real, -this.imag)
    }

    /** Returns `real - imag*i`. */
    public func conjugate(): Complex64 {
        return Complex64(this.real, -this.imag)
    }

    /**
     * True when both the real and the imaginary absolute differences are
     * strictly less than `atol`. Any NaN component makes this false.
     */
    public func approxEqual(right: Complex64, atol: Float64): Bool {
        return abs(this.real - right.real) < atol && abs(this.imag - right.imag) < atol
    }
}

extend Complex64 <: ToString {
    /** Formats as `<real>+<imag>i`, or `<real>-<|imag|>i` when `imag` is negative. */
    public func toString(): String {
        let negative = this.imag < 0.0
        let sign = if (negative) { "-" } else { "+" }
        let magnitude = if (negative) { -this.imag } else { this.imag }
        return this.real.toString() + sign + magnitude.toString() + "i"
    }
}

/**
 * Complex number with two Float32 components (`real`, `imag`).
 */
public struct Complex32 {
    public var real: Float32
    public var imag: Float32

    public init(real: Float32, imag: Float32) {
        this.real = real
        this.imag = imag
    }

    /**
     * Squared modulus `real^2 + imag^2`.
     * Computed directly, so it overflows to Inf once a component exceeds
     * about 1.8e19; `norm()` does not go through this value.
     */
    public func normsq(): Float32 {
        return this.real * this.real + this.imag * this.imag
    }

    /**
     * Modulus |z| = sqrt(real^2 + imag^2).
     *
     * Scales by the larger component before squaring. The result therefore
     * stays finite and accurate wherever |z| itself is representable. The
     * naive `sqrt(normsq())` overflows for Float32 components above about
     * 1.8e19 and underflows below about 1e-19. NaN components still produce
     * NaN, and an infinite component (with no NaN) produces +Inf.
     */
    public func norm(): Float32 {
        let a = abs(this.real)
        let b = abs(this.imag)
        // If either is NaN, `a > b` is false, but the NaN still reaches the arithmetic below.
        let big = if (a > b) { a } else { b }
        let small = if (a > b) { b } else { a }
        // Equal magnitudes (including 0/0 and Inf/Inf, whose quotient would be NaN) have ratio 1.
        let ratio: Float32 = if (big == small) { 1.0 } else { small / big }
        return big * sqrt(1.0 + ratio * ratio)
    }

    /**
     * Reads one Complex32 from `ptr` at `offset` via the C accessors.
     * NOTE(review): `offset` units (element index vs. bytes) are defined by the C side.
     */
    public static func read(ptr: CPointer<Unit>, offset: Int64): Complex32 {
        let real = unsafe { c_getComplex32Real(ptr, UIntNative(offset)) }
        let imag = unsafe { c_getComplex32Imag(ptr, UIntNative(offset)) }
        return Complex32(real, imag)
    }

    /** Writes `val` to `ptr` at `offset` via the C accessor. */
    public static func write(ptr: CPointer<Unit>, val: Complex32, offset: Int64): Unit {
        unsafe { c_writeComplex32(ptr, val.real, val.imag, UIntNative(offset)) }
    }
}

extend Complex32 <: Equatable<Complex32> {
    /** Exact component-wise equality (IEEE semantics: a NaN part never compares equal). */
    public operator func ==(x: Complex32): Bool {
        if (this.real != x.real) {
            return false
        }
        return this.imag == x.imag
    }

    /** Logical negation of `==`. */
    public operator func !=(x: Complex32): Bool {
        return !(this == x)
    }
}

extend Complex32 <: Complex<Complex32> {
    /** Embeds an integer as `n + 0i` (the conversion to Float32 may round large `n`). */
    public static func fromInt(n: Int64): Complex32 {
        return Complex32(Float32(n), Float32(0.0))
    }

    /** Size of one element: two Float32 components = 8 bytes. */
    public static func getSize(): Int64 {
        return 8
    }

    /** Type name string for this element type. */
    public static func getType(): String {
        return "Complex32"
    }

    public operator func +(x: Complex32): Complex32 {
        return Complex32(this.real + x.real, this.imag + x.imag)
    }

    public operator func -(x: Complex32): Complex32 {
        return Complex32(this.real - x.real, this.imag - x.imag)
    }

    /** (a + bi)(c + di) = (ac - bd) + (ad + bc)i */
    public operator func *(x: Complex32): Complex32 {
        return Complex32(this.real * x.real - this.imag * x.imag,
                         this.real * x.imag + this.imag * x.real)
    }

    /**
     * Complex division using Smith's algorithm.
     *
     * The textbook form divides by `c*c + d*d`. In Float32 that overflows once
     * a divisor component exceeds about 1.8e19 and underflows below about
     * 1e-19, even when the quotient is representable. Smith's method first
     * divides by the larger divisor component. Dividing by 0 + 0i yields NaN
     * components, as before.
     */
    public operator func /(x: Complex32): Complex32 {
        let a = this.real
        let b = this.imag
        let c = x.real
        let d = x.imag
        if (abs(c) >= abs(d)) {
            // |d/c| <= 1, so no intermediate can blow up.
            let r = d / c
            let den = c + d * r
            return Complex32((a + b * r) / den, (b - a * r) / den)
        } else {
            // |c/d| < 1
            let r = c / d
            let den = c * r + d
            return Complex32((a * r + b) / den, (b * r - a) / den)
        }
    }

    public operator func -(): Complex32 {
        return Complex32(-this.real, -this.imag)
    }

    /** Returns `real - imag*i`. */
    public func conjugate(): Complex32 {
        return Complex32(this.real, -this.imag)
    }

    /**
     * True when both the real and the imaginary absolute differences
     * (widened to Float64) are strictly less than `atol`. Any NaN component
     * makes this false.
     */
    public func approxEqual(right: Complex32, atol: Float64): Bool {
        return abs(this.real - right.real).toFloat64() < atol &&
               abs(this.imag - right.imag).toFloat64() < atol
    }
}

extend Complex32 <: ToString {
    /** Formats as `<real>+<imag>i`, or `<real>-<|imag|>i` when `imag` is negative. */
    public func toString(): String {
        let negative = this.imag < 0.0
        let sign = if (negative) { "-" } else { "+" }
        let magnitude = if (negative) { -this.imag } else { this.imag }
        return this.real.toString() + sign + magnitude.toString() + "i"
    }
}


@Test
public class TestComplex {
    // Norm, +, -, * and conjugate for the Float64-backed type.
    @TestCase
    func testComplex64(): Unit {
        let x = Complex64(2.0, 1.0)
        @Assert(approxEqual(x.normsq(), 5.0, atol:1e-10))
        @Assert(approxEqual(x.norm(), 2.23606797749979, atol:1e-10))

        let y = Complex64(1.0, 1.0)
        @Assert(approxEqualC(x + y, Complex64(3.0, 2.0), atol:1e-10))
        @Assert(approxEqualC(x - y, Complex64(1.0, 0.0), atol:1e-10))
        @Assert(approxEqualC(x * y, Complex64(1.0, 3.0), atol:1e-10))
        @Assert(approxEqualC(x.conjugate(), Complex64(2.0, -1.0), atol:1e-10))
    }

    // Division, negation, equality and fromInt for the Float64-backed type.
    @TestCase
    func testComplex64DivisionAndMisc(): Unit {
        let x = Complex64(2.0, 1.0)
        let y = Complex64(1.0, 1.0)
        // (2 + i) / (1 + i) = (2 + i)(1 - i) / 2 = 1.5 - 0.5i
        @Assert(approxEqualC(x / y, Complex64(1.5, -0.5), atol:1e-10))
        // Division must undo multiplication.
        @Assert(approxEqualC((x / y) * y, x, atol:1e-10))
        @Assert(approxEqualC(-x, Complex64(-2.0, -1.0), atol:1e-10))
        @Assert(x == Complex64(2.0, 1.0))
        @Assert(x != y)
        @Assert(approxEqualC(Complex64.fromInt(3), Complex64(3.0, 0.0), atol:1e-10))
    }

    // Norm, +, -, * and conjugate for the Float32-backed type.
    @TestCase
    func testComplex32(): Unit {
        let x = Complex32(2.0, 1.0)
        @Assert(approxEqual(x.normsq(), Float32(5.0), atol:1e-5))
        @Assert(approxEqual(x.norm(), Float32(2.236068), atol:1e-5))

        let y = Complex32(1.0, 1.0)
        @Assert(approxEqualC(x + y, Complex32(3.0, 2.0), atol:1e-5))
        @Assert(approxEqualC(x - y, Complex32(1.0, 0.0), atol:1e-5))
        @Assert(approxEqualC(x * y, Complex32(1.0, 3.0), atol:1e-5))
        @Assert(approxEqualC(x.conjugate(), Complex32(2.0, -1.0), atol:1e-5))
    }

    // Division, negation, equality and fromInt for the Float32-backed type.
    @TestCase
    func testComplex32DivisionAndMisc(): Unit {
        let x = Complex32(2.0, 1.0)
        let y = Complex32(1.0, 1.0)
        // (2 + i) / (1 + i) = 1.5 - 0.5i
        @Assert(approxEqualC(x / y, Complex32(1.5, -0.5), atol:1e-5))
        @Assert(approxEqualC((x / y) * y, x, atol:1e-5))
        @Assert(approxEqualC(-x, Complex32(-2.0, -1.0), atol:1e-5))
        @Assert(x == Complex32(2.0, 1.0))
        @Assert(x != y)
        @Assert(approxEqualC(Complex32.fromInt(3), Complex32(3.0, 0.0), atol:1e-5))
    }
}