/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

// The Cangjie API is in Beta. For details on its capabilities and limitations, please refer to the README file.

package std.unittest

import std.math.abs

// float delta within interval START

// Checks that a Float16 delta lies in the half-open interval [0.0, Float16.Max).
// Returns (true, "") when it does; otherwise (false, <explanation including the actual delta>).
// A NaN delta fails both comparisons and is therefore reported as out of range.
private func isDeltaWithinInterval(delta: Float16): (Bool, String) {
    let inRange = delta >= 0.0f16 && delta < Float16.Max
    if (inRange) {
        (true, "")
    } else {
        (false, "delta must be in interval [0.0, Float16.Max). Actual delta: ${delta}.")
    }
}

// Checks that a Float32 delta lies in the half-open interval [0.0, Float32.Max).
// Returns (true, "") when it does; otherwise (false, <explanation including the actual delta>).
// A NaN delta fails both comparisons and is therefore reported as out of range.
private func isDeltaWithinInterval(delta: Float32): (Bool, String) {
    let inRange = delta >= 0.0f32 && delta < Float32.Max
    if (inRange) {
        (true, "")
    } else {
        (false, "delta must be in interval [0.0, Float32.Max). Actual delta: ${delta}.")
    }
}

// Checks that a Float64 delta lies in the half-open interval [0.0, Float64.Max).
// Returns (true, "") when it does; otherwise (false, <explanation including the actual delta>).
// A NaN delta fails both comparisons and is therefore reported as out of range.
private func isDeltaWithinInterval(delta: Float64): (Bool, String) {
    let inRange = delta >= 0.0f64 && delta < Float64.Max
    if (inRange) {
        (true, "")
    } else {
        (false, "delta must be in interval [0.0, Float64.Max). Actual delta: ${delta}.")
    }
}

// float delta within interval END

// is float delta NaN START

// Error text reported when a NaN delta is supplied; shared by the isDeltaNaN overloads.
private let DELTA_IS_NAN_MESSAGE: String = "delta is NaN."

// Reports whether a Float16 delta is NaN.
// Note the sense of the flag: `true` here means the delta is bad. The result is
// (true, DELTA_IS_NAN_MESSAGE) for NaN and (false, "") for any other value.
private func isDeltaNaN(delta: Float16): (Bool, String) {
    let deltaIsNaN = delta.isNaN()
    let message = if (deltaIsNaN) { DELTA_IS_NAN_MESSAGE } else { "" }
    (deltaIsNaN, message)
}

// Reports whether a Float32 delta is NaN.
// Note the sense of the flag: `true` here means the delta is bad. The result is
// (true, DELTA_IS_NAN_MESSAGE) for NaN and (false, "") for any other value.
private func isDeltaNaN(delta: Float32): (Bool, String) {
    let deltaIsNaN = delta.isNaN()
    let message = if (deltaIsNaN) { DELTA_IS_NAN_MESSAGE } else { "" }
    (deltaIsNaN, message)
}

// Reports whether a Float64 delta is NaN.
// Note the sense of the flag: `true` here means the delta is bad. The result is
// (true, DELTA_IS_NAN_MESSAGE) for NaN and (false, "") for any other value.
private func isDeltaNaN(delta: Float64): (Bool, String) {
    let deltaIsNaN = delta.isNaN()
    let message = if (deltaIsNaN) { DELTA_IS_NAN_MESSAGE } else { "" }
    (deltaIsNaN, message)
}

// is float delta NaN END

// is delta valid START

// Validates a plain Float16 delta: it must not be NaN and must lie in [0.0, Float16.Max).
// Returns (true, "") on success, or (false, <reason>) for the first failing check.
// The NaN check runs first, so a NaN delta is always reported as NaN rather than out of range.
private func isDeltaValid(delta: Float16): (Bool, String) {
    let (deltaIsNaN, nanReason) = isDeltaNaN(delta)
    if (deltaIsNaN) {
        return (false, nanReason)
    }

    let (inRange, rangeReason) = isDeltaWithinInterval(delta)
    if (inRange) {
        (true, "")
    } else {
        (false, rangeReason)
    }
}

// Validates a plain Float32 delta: it must not be NaN and must lie in [0.0, Float32.Max).
// Returns (true, "") on success, or (false, <reason>) for the first failing check.
// The NaN check runs first, so a NaN delta is always reported as NaN rather than out of range.
private func isDeltaValid(delta: Float32): (Bool, String) {
    let (deltaIsNaN, nanReason) = isDeltaNaN(delta)
    if (deltaIsNaN) {
        return (false, nanReason)
    }

    let (inRange, rangeReason) = isDeltaWithinInterval(delta)
    if (inRange) {
        (true, "")
    } else {
        (false, rangeReason)
    }
}

// Validates a plain Float64 delta: it must not be NaN and must lie in [0.0, Float64.Max).
// Returns (true, "") on success, or (false, <reason>) for the first failing check.
// The NaN check runs first, so a NaN delta is always reported as NaN rather than out of range.
private func isDeltaValid(delta: Float64): (Bool, String) {
    let (deltaIsNaN, nanReason) = isDeltaNaN(delta)
    if (deltaIsNaN) {
        return (false, nanReason)
    }

    let (inRange, rangeReason) = isDeltaWithinInterval(delta)
    if (inRange) {
        (true, "")
    } else {
        (false, rangeReason)
    }
}

// is delta valid END

// is relative delta valid START

// Validates both components of a Float16 RelativeDelta with the plain-delta rules.
// The absolute component is checked first. On failure, the reason is prefixed with
// "Absolute " or "Relative " to name the offending component. Returns (true, "") if both pass.
private func isRelativeDeltaValid(delta: RelativeDelta<Float16>): (Bool, String) {
    let (absoluteOk, absoluteReason) = isDeltaValid(delta.absolute)
    if (absoluteOk) {
        // Only inspect the relative part once the absolute part is known to be valid.
        let (relativeOk, relativeReason) = isDeltaValid(delta.relative)
        if (relativeOk) {
            (true, "")
        } else {
            (false, "Relative ${relativeReason}")
        }
    } else {
        (false, "Absolute ${absoluteReason}")
    }
}

// Validates both components of a Float32 RelativeDelta with the plain-delta rules.
// The absolute component is checked first. On failure, the reason is prefixed with
// "Absolute " or "Relative " to name the offending component. Returns (true, "") if both pass.
private func isRelativeDeltaValid(delta: RelativeDelta<Float32>): (Bool, String) {
    let (absoluteOk, absoluteReason) = isDeltaValid(delta.absolute)
    if (absoluteOk) {
        // Only inspect the relative part once the absolute part is known to be valid.
        let (relativeOk, relativeReason) = isDeltaValid(delta.relative)
        if (relativeOk) {
            (true, "")
        } else {
            (false, "Relative ${relativeReason}")
        }
    } else {
        (false, "Absolute ${absoluteReason}")
    }
}

// Validates both components of a Float64 RelativeDelta with the plain-delta rules.
// The absolute component is checked first. On failure, the reason is prefixed with
// "Absolute " or "Relative " to name the offending component. Returns (true, "") if both pass.
private func isRelativeDeltaValid(delta: RelativeDelta<Float64>): (Bool, String) {
    let (absoluteOk, absoluteReason) = isDeltaValid(delta.absolute)
    if (absoluteOk) {
        // Only inspect the relative part once the absolute part is known to be valid.
        let (relativeOk, relativeReason) = isDeltaValid(delta.relative)
        if (relativeOk) {
            (true, "")
        } else {
            (false, "Relative ${relativeReason}")
        }
    } else {
        (false, "Absolute ${absoluteReason}")
    }
}

// is relative delta valid END

// Validates the inputs of an `isNear` comparison and throws on the first problem found.
// - f1, f2: expected to be Float16/Float32/Float64 values that must not be NaN.
// - delta: expected to be a Float16/Float32/Float64 or a RelativeDelta over one of them;
//   it must pass the matching isDeltaValid / isRelativeDeltaValid check.
// Throws IllegalArgumentException when an operand is NaN or the delta is invalid.
// Throws IllegalStateException when a value has none of the expected types.
private func validateFloats(f1: Any, f2: Any, delta: Any) {
    // True when `value` is a non-NaN float; throws for any non-float value.
    // `position` ("first" / "second") is used only in that type-error message.
    func isNonNaNFloat(value: Any, position: String): Bool {
        match (value) {
            case v: Float16 => !v.isNaN()
            case v: Float32 => !v.isNaN()
            case v: Float64 => !v.isNaN()
            case _ => throw IllegalStateException("Unexpected non float value in ${position} argument.")
        }
    }

    if (!isNonNaNFloat(f1, "first")) {
        throw IllegalArgumentException("First argument is NaN.")
    }
    if (!isNonNaNFloat(f2, "second")) {
        throw IllegalArgumentException("Second argument is NaN.")
    }

    // Dispatch to the overload matching the concrete delta type.
    let (deltaOk, deltaReason) = match (delta) {
        case d: Float16 => isDeltaValid(d)
        case d: Float32 => isDeltaValid(d)
        case d: Float64 => isDeltaValid(d)
        case d: RelativeDelta<Float16> => isRelativeDeltaValid(d)
        case d: RelativeDelta<Float32> => isRelativeDeltaValid(d)
        case d: RelativeDelta<Float64> => isRelativeDeltaValid(d)
        case _ => throw IllegalStateException("Delta has invalid type.")
    }

    if (!deltaOk) {
        throw IllegalArgumentException(deltaReason)
    }
}

// Adds approximate equality (`isNear`) to Float16, with either a plain absolute tolerance
// or a RelativeDelta (absolute part plus relative part scaled by the larger magnitude).
// Both overloads first call validateFloats, which throws IllegalArgumentException for a
// NaN operand or an invalid delta. If either operand is infinite, the result is false,
// even when both operands are the same infinity.
extend Float16 <: NearEquatable<Float16, Float16> & NearEquatable<Float16, RelativeDelta<Float16>> {
    // Near when |this - obj| <= delta.
    public func isNear(obj: Float16, delta!: Float16): Bool {
        validateFloats(this, obj, delta)
        let bothFinite = !this.isInf() && !obj.isInf()
        bothFinite && abs(this - obj) <= delta
    }

    // Near when |this - obj| <= delta.absolute + delta.relative * max(|this|, |obj|).
    public func isNear(obj: Float16, delta!: RelativeDelta<Float16>): Bool {
        validateFloats(this, obj, delta)
        // Rejecting infinities here stops `relative * Inf` from making every value near.
        if (this.isInf() || obj.isInf()) {
            return false
        }
        let tolerance = delta.absolute + delta.relative * max(abs(this), abs(obj))
        abs(this - obj) <= tolerance
    }
}

// Adds approximate equality (`isNear`) to Float32, with either a plain absolute tolerance
// or a RelativeDelta (absolute part plus relative part scaled by the larger magnitude).
// Both overloads first call validateFloats, which throws IllegalArgumentException for a
// NaN operand or an invalid delta. If either operand is infinite, the result is false,
// even when both operands are the same infinity.
extend Float32 <: NearEquatable<Float32, Float32> & NearEquatable<Float32, RelativeDelta<Float32>> {
    // Near when |this - obj| <= delta.
    public func isNear(obj: Float32, delta!: Float32): Bool {
        validateFloats(this, obj, delta)
        let bothFinite = !this.isInf() && !obj.isInf()
        bothFinite && abs(this - obj) <= delta
    }

    // Near when |this - obj| <= delta.absolute + delta.relative * max(|this|, |obj|).
    public func isNear(obj: Float32, delta!: RelativeDelta<Float32>): Bool {
        validateFloats(this, obj, delta)
        // Rejecting infinities here stops `relative * Inf` from making every value near.
        if (this.isInf() || obj.isInf()) {
            return false
        }
        let tolerance = delta.absolute + delta.relative * max(abs(this), abs(obj))
        abs(this - obj) <= tolerance
    }
}

// Adds approximate equality (`isNear`) to Float64, with either a plain absolute tolerance
// or a RelativeDelta (absolute part plus relative part scaled by the larger magnitude).
// Both overloads first call validateFloats, which throws IllegalArgumentException for a
// NaN operand or an invalid delta. If either operand is infinite, the result is false,
// even when both operands are the same infinity.
extend Float64 <: NearEquatable<Float64, Float64> & NearEquatable<Float64, RelativeDelta<Float64>> {
    // Near when |this - obj| <= delta.
    public func isNear(obj: Float64, delta!: Float64): Bool {
        validateFloats(this, obj, delta)
        let bothFinite = !this.isInf() && !obj.isInf()
        bothFinite && abs(this - obj) <= delta
    }

    // Near when |this - obj| <= delta.absolute + delta.relative * max(|this|, |obj|).
    public func isNear(obj: Float64, delta!: RelativeDelta<Float64>): Bool {
        validateFloats(this, obj, delta)
        // Rejecting infinities here stops `relative * Inf` from making every value near.
        if (this.isInf() || obj.isInf()) {
            return false
        }
        let tolerance = delta.absolute + delta.relative * max(abs(this), abs(obj))
        abs(this - obj) <= tolerance
    }
}