/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

// The Cangjie API is in Beta. For details on its capabilities and limitations, please refer to the README file.

package std.unittest.diff

import std.convert.Formattable
import std.math.ceil
import std.unittest.common.PrettyPrinter

interface AssertPrintableFloat<F> where F <: AnyFloat & Formattable & ToString & Equatable<F> {
    /*
     * To print the shortest form of float number, we use r'g' specifier which also removes all trailing zeros,
     * there is no way to configure this behaviour specifically,
     * but we need to keep the fractional part even it's empty to highlight that it's a floating-point number,
     * so for empty fractional part, we add ".0" explicitly.
     */
    static func addTrailingPointZeroIfNeeded(floatStr: String) {
        if (!floatStr.contains(".") && !floatStr.contains("e")) {
            floatStr + ".0"
        } else {
            floatStr
        }
    }

    static func pprint(left: F, right: F, delta: F, pp: PrettyPrinter, leftPrefix: String, rightPrefix: String,
        level: Int64): PrettyPrinter {
        if (level == 1) {
            pp.pprintNotEqual(leftPrefix, rightPrefix).newLine()
        }

        var sigDigits = 6
        let isLeftSpecialValue = left.isNaN() || left.isInf()
        let isRightSpecialValue = right.isNaN() || right.isInf()

        if (isLeftSpecialValue || isRightSpecialValue) {
            let leftStr = if (isLeftSpecialValue) { left.toString() } else {
                addTrailingPointZeroIfNeeded(left.format(".${sigDigits}g"))
            }
            let rightStr = if (isRightSpecialValue) { right.toString() } else {
                addTrailingPointZeroIfNeeded(right.format(".${sigDigits}g"))
            }
            return pp.pprintFloatDiff(leftStr, rightStr, Option<String>.None, leftPrefix, rightPrefix)
        }

        var leftStr = left.format(".${sigDigits}g")
        var rightStr = right.format(".${sigDigits}g")
        var deltaStr = delta.format(".${sigDigits}g")

        while (leftStr == rightStr && deltaStr != "." && sigDigits <= 50) {
            sigDigits = Int64(ceil(Float64(sigDigits) * 3.0 / 2.0))
            leftStr = left.format(".${sigDigits}g")
            rightStr = right.format(".${sigDigits}g")
            deltaStr = delta.format(".${sigDigits}g")
        }

        return pp.pprintFloatDiff(
            addTrailingPointZeroIfNeeded(leftStr),
            addTrailingPointZeroIfNeeded(rightStr),
            addTrailingPointZeroIfNeeded(deltaStr),
            leftPrefix,
            rightPrefix)
    }
}

extend<T> Array<T> where T <: Comparable<T> {
    func max() {
        var result = this[0]
        for (e in this) {
            if (e > result) {
                result = e
            }
        }
        result
    }
}

extend PrettyPrinter {
    func pprintFloatDiff(leftStr: String, rightStr: String, deltaStrOpt: Option<String>, leftPrefix: String,
        rightPrefix: String): PrettyPrinter {
        let leftRightSize = UInt64([leftPrefix.size, rightPrefix.size, 4].max()) + 1
        let deltaStr = deltaStrOpt ?? return appendRightAligned(leftPrefix, leftRightSize)
            .append(": ")
            .colored(RED, leftStr)
            .newLine()
            .appendRightAligned(rightPrefix, leftRightSize)
            .append(": ")
            .colored(RED, rightStr)
            .newLine()

        let (commonPrefixLen, _) = OneLineDiffBuilder.commonPrefix(leftStr, rightStr)
        let size = max(leftRightSize, UInt64("delta".size) + 1)

        return appendRightAligned(leftPrefix, size)
            .append(": ")
            .append(leftStr[..commonPrefixLen])
            .colored(RED, leftStr[commonPrefixLen..])
            .newLine()
            .appendRightAligned(rightPrefix, size)
            .append(": ")
            .append(rightStr[..commonPrefixLen])
            .colored(RED, rightStr[commonPrefixLen..])
            .newLine()
            .appendRightAligned("delta", size)
            .append(": ")
            .colored(YELLOW, deltaStr)
            .newLine()
    }
}

// Below are temporary hacks to workaround a compiler issue
// related to produced wrong symbols for floats (further, linker failed)
// in case of extending them by `AssertPrintableFloat` directly

interface AnyFloat {
    func isNaN(): Bool
    func isInf(): Bool
}

extend Float16 <: AnyFloat {}

extend Float32 <: AnyFloat {}

extend Float64 <: AnyFloat {}

extend Float16 <: AssertPrintable<Float16> {
    public prop hasNestedDiff: Bool {
        get() { false }
    }
    public func pprintForAssertion(pp: PrettyPrinter, right: Float16, leftPrefix: String, rightPrefix: String,
        level: Int64): PrettyPrinter {
        return AssertPrintableFloat<Float16>.pprint(this, right, this - right, pp, leftPrefix, rightPrefix, level)
    }
}

extend Float32 <: AssertPrintable<Float32> {
    public prop hasNestedDiff: Bool {
        get() { false }
    }
    public func pprintForAssertion(pp: PrettyPrinter, right: Float32, leftPrefix: String, rightPrefix: String,
        level: Int64): PrettyPrinter {
        return AssertPrintableFloat<Float32>.pprint(this, right, this - right, pp, leftPrefix, rightPrefix, level)
    }
}

extend Float64 <: AssertPrintable<Float64> {
    public prop hasNestedDiff: Bool {
        get() { false }
    }
    public func pprintForAssertion(pp: PrettyPrinter, right: Float64, leftPrefix: String, rightPrefix: String,
        level: Int64): PrettyPrinter {
        return AssertPrintableFloat<Float64>.pprint(this, right, this - right, pp, leftPrefix, rightPrefix, level)
    }
}