/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.http

import stdx.encoding.base64.toBase64String
import stdx.encoding.url.{URL, Form}
import std.collection.*
import std.convert.Parsable
import std.fs.*
import std.io.*
import stdx.log.*
import std.process.*
import std.net.*
import std.time.*
import std.sync.*

// Lower-case hexadecimal digits indexed by nibble value; used by Int64.toHexString.
const INT64_TO_HEX: String = "0123456789abcdef"
// Mask selecting the lowest nibble (4 bits) of an Int64.
const MASK_I64: Int64 = 0x0f
// Default maximum number of bytes rendered by sanitizeForLog.
// NOTE(review): "LENTH" is a misspelling of "LENGTH"; kept because other code refers to this name.
const DEFAULT_LOG_LENTH: Int64 = 256
// Three-letter English day names, index 0 = Sunday (used by getRFC1123String).
let shortDayOfWeekForm: Array<Str> = [Str("Sun"), Str("Mon"), Str("Tue"), Str("Wed"), Str("Thu"), Str("Fri"), Str("Sat")]
// Three-letter English month names, index 0 = January (used by getRFC1123String).
let shortMonthForm: Array<Str> = [Str("Jan"), Str("Feb"), Str("Mar"), Str("Apr"), Str("May"), Str("Jun"), Str("Jul"),
    Str("Aug"), Str("Sep"), Str("Oct"), Str("Nov"), Str("Dec")]
/*
 * Lower-case spellings of common header names, keyed by their Str form. getString and
 * toLowerCaseStr return the String stored here when a name is found, presumably so that
 * frequent names reuse one String instead of allocating per occurrence.
 * The entries "authority", "method", "path", "scheme" and "status" presumably correspond to
 * the HTTP/2 pseudo-header names without their leading ':'.
 */
let stringTable: HashMap<Str, String> = HashMap<Str, String>(
    [
        (Str("host"), "host"),
        (Str("connection"), "connection"),
        (Str("upgrade-insecure-requests"), "upgrade-insecure-requests"),
        (Str("user-agent"), "user-agent"),
        (Str("accept"), "accept"),
        (Str("accept-encoding"), "accept-encoding"),
        (Str("accept-language"), "accept-language"),
        (Str("authority"), "authority"),
        (Str("method"), "method"),
        (Str("path"), "path"),
        (Str("scheme"), "scheme"),
        (Str("status"), "status"),
        (Str("accept-charset"), "accept-charset"),
        (Str("accept-ranges"), "accept-ranges"),
        (Str("access-control-allow-origin"), "access-control-allow-origin"),
        (Str("age"), "age"),
        (Str("authorization"), "authorization"),
        (Str("cache-control"), "cache-control"),
        (Str("content-disposition"), "content-disposition"),
        (Str("content-encoding"), "content-encoding"),
        (Str("content-language"), "content-language"),
        (Str("content-length"), "content-length"),
        (Str("content-location"), "content-location"),
        (Str("content-range"), "content-range"),
        (Str("content-type"), "content-type"),
        (Str("cookie"), "cookie"),
        (Str("date"), "date"),
        (Str("etag"), "etag"),
        (Str("expect"), "expect"),
        (Str("expires"), "expires"),
        (Str("from"), "from"),
        (Str("if-match"), "if-match"),
        (Str("if-modified-since"), "if-modified-since"),
        (Str("if-none-match"), "if-none-match"),
        (Str("if-range"), "if-range"),
        (Str("if-unmodified-since"), "if-unmodified-since"),
        (Str("last-modified"), "last-modified"),
        (Str("link"), "link"),
        (Str("location"), "location"),
        (Str("max-forwards"), "max-forwards"),
        (Str("proxy-authenticate"), "proxy-authenticate"),
        (Str("proxy-authorization"), "proxy-authorization"),
        (Str("range"), "range"),
        (Str("referer"), "referer"),
        (Str("refresh"), "refresh"),
        (Str("retry-after"), "retry-after"),
        (Str("server"), "server"),
        (Str("set-cookie"), "set-cookie"),
        (Str("strict-transport-security"), "strict-transport-security"),
        (Str("transfer-encoding"), "transfer-encoding"),
        (Str("vary"), "vary"),
        (Str("via"), "via"),
        (Str("www-authenticate"), "www-authenticate")
    ]
)

/*
 * Converts a Str to a String, returning the shared instance from stringTable when the
 * name is a known header name and a freshly converted String otherwise.
 */
func getString(str: Str): String {
    return stringTable.get(str) ?? str.toString()
}

/*
 * Returns `str` as a lower-case Str. Text without upper-case letters is wrapped unchanged.
 * Otherwise stringTable is consulted first, with toAsciiLower() as the fallback.
 * NOTE(review): stringTable keys are lower-case, so looking up the original mixed-case text
 * only hits if Str equality/hashing is case-insensitive. Confirm in Str's definition.
 */
func toLowerCaseStr(str: String): Str {
    if (str.hasUpper()) {
        let lowered = stringTable.get(Str(str)) ?? str.toAsciiLower()
        return Str(lowered)
    }
    return Str(str)
}

extend Int64 {
    /**
     * Formats this value as lower-case hexadecimal with no prefix and no leading zeros,
     * e.g. 255 -> "ff", 4096 -> "1000". Zero is formatted as "0".
     * NOTE: negative values yield an empty string (no digits are produced for them).
     */
    func toHexString(): String {
        if (this == 0) {
            // without this, the digit loops below never run and "" would be returned
            return "0"
        }
        // count hex digits first so the output buffer is allocated exactly once
        var digits = 0
        var num = this
        while (num > 0) {
            digits++
            num = num >> 4
        }
        let arr = Array<Byte>(digits, repeat: b'0')
        // fill from the least significant nibble (rightmost position) backwards
        num = this
        var i = digits - 1
        while (num > 0) {
            arr[i] = INT64_TO_HEX[num & MASK_I64]
            num = num >> 4
            i--
        }
        return String.fromUtf8(arr)
    }

    /**
     * Parses a non-empty run of hex digits (either case, no "0x" prefix, no sign) into an Int64.
     *
     * @throws IllegalArgumentException if the input is empty, contains a non-hex byte, or does not
     *         fit into a non-negative Int64 (more than 16 digits, or 16 digits with a first digit > 7).
     *         Leading zeros count toward the 16-digit limit.
     */
    static func fromHexStr(hexRaw: Str): Int64 {
        if (hexRaw.isEmpty()) {
            throw IllegalArgumentException("Invalid hex string.")
        }
        if (hexRaw.size > 16 || (hexRaw.size == 16 &&
            (fromHexByte(hexRaw[0]) ?? throw IllegalArgumentException("Invalid hex string.")) > 7)) { // cjlint-ignore !G.EXP.03
            throw IllegalArgumentException("The hex number is too big.")
        }
        var number: Int64 = 0
        for (i in 0..hexRaw.size) {
            match (fromHexByte(hexRaw[i])) {
                case Some(v) =>
                    number <<= 4
                    number |= v
                case None => throw IllegalArgumentException("Invalid hex string.")
            }
        }
        return number
    }

    /** Returns the value (0-15) of one ASCII hex digit, or None if `h` is not a hex digit. */
    static func fromHexByte(h: Byte): ?Int64 {
        return match {
            case b'0' <= h && h <= b'9' => Int64(h - b'0')
            case b'a' <= h && h <= b'f' => Int64(h - b'a' + 10)
            case b'A' <= h && h <= b'F' => Int64(h - b'A' + 10)
            case _ => None
        }
    }
}

/* A 3xx status other than 304 (Not Modified) means the client has to follow a redirect. */
let needRedirect = {code: UInt16 =>
    let isRedirection = code >= 300 && code < 400
    isRedirection && code != HttpStatusCode.STATUS_NOT_MODIFIED
}

/*
 * Normalizes a path by collapsing repeated '/', dropping "." segments and resolving ".."
 * against the preceding segment, never climbing above the root.
 *
 * The result always starts with '/' and keeps a trailing '/'. A trailing "." or ".."
 * resolves to a directory path ending in '/', as in RFC 3986 section 5.2.4
 * (remove_dot_segments).
 * Examples: "/a//b/./c/../d" -> "/a/b/d", "/a/b/.." -> "/a/", "a/b" -> "/a/b".
 * "" and "/" are returned unchanged.
 */
func canonicalPath(path: String): String {
    if (path.isEmpty() || path == "/") {
        return path
    }
    let raw = unsafe { path.rawData() }
    // Output is at most the input plus a leading '/', so one buffer of raw.size + 1 suffices.
    let newPath = Array<Byte>(raw.size + 1, repeat: b'/')
    var partLeft = 0
    var partRight = 0
    // newPath[0] is the leading '/'; writeIdx always points just past a '/' in newPath
    var writeIdx = 1
    while (partRight < raw.size) {
        let b = raw[partRight]
        if (b != b'/') {
            partRight++
            continue
        }
        match (path[partLeft..partRight]) {
            case "" | "." => ()
            case ".." => writeIdx = findPrevSlash(newPath, writeIdx) + 1
            case _ =>
                raw.copyTo(newPath, partLeft, writeIdx, partRight - partLeft)
                writeIdx += partRight - partLeft
                newPath[writeIdx] = b'/'
                writeIdx++
        }
        partRight++
        partLeft = partRight
    }
    // The last segment has no terminating '/'. It still needs dot-segment handling, otherwise
    // "/a/b/.." would be returned verbatim (and e.g. "/static/.." would escape its directory).
    if (partLeft < raw.size) {
        match (path[partLeft..]) {
            case "." => ()
            case ".." => writeIdx = findPrevSlash(newPath, writeIdx) + 1
            case _ =>
                raw.copyTo(newPath, partLeft, writeIdx, raw.size - partLeft)
                writeIdx += raw.size - partLeft
        }
    }
    return unsafe { String.fromUtf8Unchecked(newPath[..writeIdx]) }
}

/*
 * Returns the index of the '/' that precedes the last segment written to path[..curIdx].
 * This is where a ".." segment rewinds to; the caller continues writing at the returned
 * index + 1. `curIdx` is expected to point just past a '/' and path[0] to be '/'.
 * Returns 0 when only the leading '/' has been written, so ".." never climbs above the root.
 */
func findPrevSlash(path: Array<Byte>, curIdx: Int64): Int64 {
    // only start slash in path
    if (curIdx == 1) {
        return 0
    }
    // skip the '/' at curIdx - 1 and walk back to the previous one
    var prev = curIdx - 2
    while (prev >= 0) {
        if (path[prev] == b'/') {
            break
        }
        prev--
    }
    return prev
}

extend String {
    /**
     * Returns true if at least one element of `values`, lower-cased with toAsciiLower(),
     * equals this string. Comparison is one-sided, so this string should already be lower-case.
     */
    func caseInsensitiveMatchOne(values: ArrayList<String>): Bool {
        for (value in values where value.toAsciiLower() == this) {
            return true
        }
        return false
    }

    /**
     * Returns true if every element of `values`, lower-cased with toAsciiLower(), equals
     * this string (vacuously true for an empty list). This string should already be lower-case.
     */
    func caseInsensitiveMatchAll(values: ArrayList<String>): Bool {
        for (value in values where value.toAsciiLower() != this) {
            return false
        }
        return true
    }

    /** Returns true if every element of `values` equals this string exactly (true for an empty list). */
    func matchAll(values: ArrayList<String>): Bool {
        for (value in values where value != this) {
            return false
        }
        return true
    }

    /** Returns `s` when this string is empty, otherwise this string. */
    func ifEmpty(s: String): String {
        if (this.isEmpty()) {
            return s
        }
        return this
    }

    /** Returns whether this string contains an upper-case letter, as decided by Str.hasUpper(). */
    func hasUpper(): Bool {
        let view = Str(this)
        return view.hasUpper()
    }
}

extend<K, V> HashMapIterator<K, V> {
    /**
     * Removes the entry most recently returned by this iterator, but only when `condition`
     * evaluates to true. `condition` is evaluated exactly once.
     */
    func removeIf(condition: () -> Bool): Unit {
        let shouldRemove: Bool = condition()
        if (shouldRemove) {
            this.remove()
        }
    }
}

extend<K, V> HashMap<K, V> {
    /**
     * Removes `key` from the map only when `condition` evaluates to true.
     * `condition` is evaluated exactly once; a missing key is not an error.
     */
    func removeKeyIf(key: K, condition: () -> Bool): Unit {
        let shouldRemove: Bool = condition()
        if (shouldRemove) {
            this.remove(key)
        }
    }
}

/*
 * Maps a file extension to the Content-Type value to send for it. Keys include the leading
 * '.' and are lower-case. The text types (css, html, javascript, xml) carry an explicit
 * "charset=utf-8" parameter.
 */
let EXT_TYPE_MAP = HashMap<String, String>(
    [
        (".avif", "image/avif"),
        (".css", "text/css; charset=utf-8"),
        (".gif", "image/gif"),
        (".htm", "text/html; charset=utf-8"),
        (".html", "text/html; charset=utf-8"),
        (".jpeg", "image/jpeg"),
        (".jpg", "image/jpeg"),
        (".js", "text/javascript; charset=utf-8"),
        (".json", "application/json"),
        (".mjs", "text/javascript; charset=utf-8"),
        (".pdf", "application/pdf"),
        (".png", "image/png"),
        (".svg", "image/svg+xml"),
        (".wasm", "application/wasm"),
        (".webp", "image/webp"),
        (".xml", "text/xml; charset=utf-8")
    ]
)

/* Clamps `d` to be non-negative: any negative duration becomes Duration.Zero. */
func checkDuration(d: Duration): Duration {
    if (d >= Duration.Zero) {
        return d
    }
    return Duration.Zero
}

/*
 * Builds the credentials part of an HTTP Basic authorization value: base64 of
 * "username:password" (the "Basic " scheme prefix is not included).
 */
func basicAuth(username: String, password: String): String {
    let credentials = username + ":" + password
    return toBase64String(unsafe { credentials.rawData() })
}

/*
 * Derives a proxy authorization value ("Basic <base64>") from the user info of `url`.
 * Returns None when the URL has no user name; a missing password is treated as "".
 */
func parseProxyAuth(url: URL): ?String {
    let user = url.rawUserInfo.username()
    if (user.isEmpty()) {
        return None
    }
    let password = url.rawUserInfo.password() ?? ""
    return "Basic ${basicAuth(user, password)}"
}

/*
 * Validates a multipart boundary (see RFC 2046 section 5.1.1).
 *
 * Returns true when `boundary` is 1..70 bytes long, does not end with a space, and every
 * character is an ASCII letter or digit or one of  ' ( ) + _ , - . / : = ?
 * NOTE(review): RFC 2046 also allows interior spaces, but ' ' is not in the accepted set, so
 * spaces are rejected anywhere and the trailing-space test is redundant. Kept as-is on purpose.
 */
func checkBoundary(boundary: String): Bool {
    let byteLen = boundary.size
    if (byteLen == 0 || byteLen > 70 || boundary.endsWith(" ")) {
        return false
    }
    for (c in boundary where !c.isAsciiNumberOrLetter()) {
        let isSpecial = match (c) {
            case '\'' | '(' | ')' | '+' | '_' | ',' | '-' | '.' | '/' | ':' | '=' | '?' => true
            case _ => false
        }
        if (!isSpecial) {
            return false
        }
    }
    return true
}

/*
 * Returns the length reported by `inputStream` when it is Seekable and that length is
 * non-negative; otherwise None (size unknown).
 *
 * `length` is read only once. Depending on the implementation it may involve I/O such as
 * seeking, and reading it twice could also observe two different values.
 * NOTE(review): this reports `length`, not `remainLength`. If a partially consumed stream can
 * reach this point, the remaining length may be what callers need. Confirm at call sites.
 */
func sizeOf(inputStream: InputStream): ?Int64 {
    match (inputStream) {
        case s: Seekable =>
            let total = s.length
            if (total >= 0) {
                return total
            }
            return None
        case _ => return None
    }
}

/*
 * A thin holder for an optional one-shot Timer:
 * 1. constructing it with `start == Duration.Max` creates and runs no Timer at all;
 * 2. `empty` is a shared instance without a Timer, so classes (and future extensions) need not
 *    model "no timer" with zero values or Option themselves. They can always call cancel().
 */
class HttpTimer {
    var timer: ?Timer = None

    static let empty: HttpTimer = HttpTimer()

    private init() {}

    /* Schedules `task` once after `start`, unless `start` is Duration.Max (meaning "never"). */
    init(start!: Duration, task!: () -> Unit) {
        let never = start == Duration.Max
        if (!never) {
            timer = Timer.once(start, task)
        }
    }

    /* Cancels the scheduled timer, if any, and forgets it; further calls are no-ops. */
    func cancel(): Unit {
        match (timer) {
            case Some(active) =>
                active.cancel()
                timer = None
            case None => ()
        }
    }
}

/**
 * Writes the current UTC time into `dateArr` in RFC 1123 date layout
 * ("Sun, 06 Nov 1994 08:49:37 ...") and returns a Str view of the result.
 *
 * Only the variable fields are written, at fixed offsets: seconds at 41-42, minutes at
 * 38-39, hours at 35-36, and the year right-aligned so its last digit lands at index 33
 * (see writeYear). Day of month, month and day of week are placed relative to `i`, the
 * index just before the first year digit.
 * NOTE(review): indices 34, 37, 40 and everything after 42 are never written here, so
 * `dateArr` is presumably a pre-filled template (':' separators, trailing " GMT", ...).
 * Confirm at the caller.
 *
 * @param dateArr buffer to fill in place; must be at least 43 bytes long.
 * @return a Str over `dateArr` from the day-of-week name to the end of the buffer.
 */
func getRFC1123String(dateArr: Array<Byte>): Str {
    let t = DateTime.nowUTC()
    writeTwoNum(dateArr, t.second, 41)
    writeTwoNum(dateArr, t.minute, 38)
    writeTwoNum(dateArr, t.hour, 35)
    let i = writeYear(dateArr, t.year)
    dateArr[i] = b' '
    // presumably Month.toInteger() is 1-based (January = 1), hence the - 1
    shortMonthForm[t.month.toInteger() - 1].raw.copyTo(dateArr, 0, i - 3, 3)
    dateArr[i - 4] = b' '
    writeTwoNum(dateArr, t.dayOfMonth, i - 6)
    dateArr[i - 7] = b' '
    dateArr[i - 8] = b','
    // presumably DayOfWeek.toInteger() maps Sunday to 0, matching shortDayOfWeekForm
    shortDayOfWeekForm[t.dayOfWeek.toInteger()].raw.copyTo(dateArr, 0, i - 11, 3)
    return Str(dateArr[i - 11..])
}

/*
 * Writes `n` as exactly two ASCII digits (zero-padded) at dateArr[idx] and dateArr[idx + 1].
 * Throws IllegalArgumentException when `n` is outside 0..99.
 */
func writeTwoNum(dateArr: Array<Byte>, n: Int64, idx: Int64): Unit {
    if (!(0 <= n && n <= 99)) {
        throw IllegalArgumentException("Invalid date number: ${n}.")
    }
    let tens = UInt8(n / 10)
    let ones = UInt8(n % 10)
    dateArr[idx] = b'0' + tens
    dateArr[idx + 1] = b'0' + ones
}

/*
 * Writes the decimal digits of `year` right-to-left so that the last digit lands at
 * index 33 of `dateArr`, and returns the index just before the first digit written.
 * NOTE(review): a year of 0 writes nothing (returns 33); a negative year would overflow
 * UInt64(year). Presumably neither happens for DateTime.nowUTC().
 */
func writeYear(dateArr: Array<Byte>, year: Int64): Int64 {
    var remaining = UInt64(year)
    var pos = 33
    while (remaining != 0) {
        dateArr[pos] = b'0' + UInt8(remaining % 10)
        remaining = remaining / 10
        pos -= 1
    }
    return pos
}

/**
 * Parses and validates an HTTP/1.x status line ("HTTP-version SP status-code SP reason").
 *
 * @param line the status line.
 * @return (version, status code); version is "HTTP/1.0" or "HTTP/1.1".
 * @throws HttpException if the version is not HTTP/1.0 or HTTP/1.1, the status code is
 *         missing or not numeric, is not exactly three digits, or is outside 100..599.
 */
func parseAndCheckResponseLine(line: Str): (String, UInt16) {
    let version = line.splitFirst(WS)?.toString() ?? throw HttpException("Invalid response status line.")
    if (version != "HTTP/1.0" && version != "HTTP/1.1") {
        throw HttpException("Invalid response status line.")
    }

    /**
     * The status code start from index 9.
     *
     * 0123456789abcde
     * ---------------
     * HTTP/1.1 200 OK
     */
    let status = line.splitFirst(WS, 9)?.toString() ?? throw HttpException("Invalid response status line.")
    // status-code = 3DIGIT (RFC 9110 section 15); without this, forms such as "0200" would
    // parse to a valid-looking 200
    if (status.size != 3) {
        throw HttpException("Invalid response status code.")
    }
    let statusCode = match (UInt16.tryParse(status)) {
        case Some(v) =>
            if (v < 100 || v > 599) {
                throw HttpException("Invalid response status code.")
            }
            v
        case None => throw HttpException("Invalid response status line.")
    }
    return (version, statusCode)
}

/**
 * Parses and validates an HTTP/1.x request line ("method SP request-target SP HTTP-version").
 *
 * @param line the request line.
 * @return (method, request-target, version) as Strings.
 * @throws HttpStatusException
 *         - STATUS_BAD_REQUEST for a missing element, an empty request-target, a method longer
 *           than MAX_METHOD_SIZE, or more than three elements;
 *         - STATUS_METHOD_NOT_ALLOWED for an empty or non-token method;
 *         - STATUS_HTTP_VERSION_NOT_SUPPORTED for a version other than HTTP/1.0 or HTTP/1.1.
 */
func parseAndCheckRequestLine(line: Str): (String, String, String) {
    // parse and check method
    let method = line.splitFirst(WS) ?? throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, // cjlint-ignore !G.EXP.03
        "Invalid request method, request line: ${line}.")
    if (method.isEmpty() || !method.byteMatches(isTokenByte)) { // cjlint-ignore !G.EXP.03
        throw HttpStatusException(HttpStatusCode.STATUS_METHOD_NOT_ALLOWED, "Invalid request method: ${method}.")
    }
    if (method.size > MAX_METHOD_SIZE) {
        throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, "Invalid method: ${method}.")
    }

    // parse and check target
    var i = method.size + 1
    let target = line.splitFirst(WS, i) ?? throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, // cjlint-ignore !G.EXP.03
        "Invalid request target, request line: ${line}.")
    // a request-target is never empty; presumably splitFirst yields "" for input like "GET  HTTP/1.1"
    if (target.isEmpty()) {
        throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST,
            "Invalid request target, request line: ${line}.")
    }

    // parse and check version
    i += target.size + 1
    let version = line.splitFirst(WS, i) ?? throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, // cjlint-ignore !G.EXP.03
        "Invalid request version, request line: ${line}.")
    if (version != Str("HTTP/1.1") && version != Str("HTTP/1.0")) { // cjlint-ignore !G.EXP.03
        throw HttpStatusException(HttpStatusCode.STATUS_HTTP_VERSION_NOT_SUPPORTED, "Incorrect version: ${version}.")
    }

    // the version must be the last element of the line
    if (i + version.size != line.size) {
        throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST,
            "Invalid request line, more than three elements: ${line}.")
    }
    return (method.toString(), target.toString(), version.toString())
}

/**
 * Parses and validates one header field line ("field-name: field-value").
 *
 * @param line the field line.
 * @param isReq whether the line belongs to a request; selects the exception type (see chooseException).
 * @return (name, trimmed value); `name` is taken from stringTable via getString when it is a known name.
 * @throws HttpStatusException (request) or HttpException (response) when there is no colon,
 *         the name is empty or not a token, or the value contains a forbidden byte.
 */
func parseAndCheckHeaderLine(line: Str, isReq: Bool): (String, String) {
    // parse and check name: field-name = token = 1*tchar, so an empty name (":x") is invalid too
    let name = line.splitFirst(SYMBOL_COLON) ?? throw chooseException("Invalid field line: ${line}.", isReq) // cjlint-ignore !G.EXP.03
    if (name.isEmpty() || !name.byteMatches(isTokenByte)) { // cjlint-ignore !G.EXP.03
        throw chooseException("Invalid field line: ${line}.", isReq)
    }
    // the name spans the whole line, i.e. the line contains no colon
    if (name.size == line.size) {
        throw chooseException("Invalid field line: ${line}.", isReq)
    }
    // parse and check value
    let value = line.slice(name.size + 1).trim()
    if (!value.byteMatches(checkValueBytes)) {
        throw chooseException("Invalid field line: ${line}.", isReq)
    }
    return (getString(name), value.toString())
}

/*
 * Returns whether `b` may appear in a header field value: HTAB (0x09), printable ASCII
 * (0x20..0x7E) and obs-text (0x80..0xFF) are allowed; other control bytes and DEL are not.
 * (The former `b > 0xff` term was always false for a Byte and has been dropped.)
 */
func checkValueBytes(b: Byte): Bool {
    if (b == 0x09) {
        return true
    }
    return b >= 0x20 && b != 0x7f
}

/*
 * Creates the exception for a malformed message: an HttpStatusException carrying
 * STATUS_BAD_REQUEST when parsing a request, a plain HttpException otherwise.
 */
func chooseException(msg: String, isReq: Bool): Exception {
    if (!isReq) {
        return HttpException(msg)
    }
    return HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, msg)
}

/*
 * Validates the framing headers and returns (content-length, isChunked).
 * A present transfer-encoding is validated by checkTe and means chunked framing;
 * content-length together with transfer-encoding is rejected with HttpException.
 * A present content-length is validated and normalized by checkCl.
 */
func checkClAndTe(headers: HttpHeaders): (?Int64, Bool) {
    let teValue = headers.getInternal("transfer-encoding")
    if (let Some(te) <- teValue) {
        checkTe(te)
    }
    let isChunked = teValue.isSome()

    var length: ?Int64 = None
    if (let Some(clValue) <- headers.getInternal("content-length")) {
        if (isChunked) {
            throw HttpException("The content-length and transfer-encoding can not be set together.")
        }
        length = checkCl(clValue, headers)
    }
    return (length, isChunked)
}

/*
 * Validates all transfer-encoding values. The very last value must end with "chunked"
 * (checkLastTeStr), and every other value must not mention "chunked" (checkOtherTeStr).
 */
func checkTe(hv: HeaderValue): Unit {
    if (!hv.isSingle()) {
        let extras = hv.extra
        let lastIdx = extras.size - 1
        checkLastTeStr(extras[lastIdx])
        for (i in 0..lastIdx) {
            checkOtherTeStr(extras[i])
        }
        checkOtherTeStr(hv.single)
        return
    }
    checkLastTeStr(hv.single)
}

/*
 * Checks the final transfer-encoding value: its last comma-separated coding must be
 * "chunked", and the codings before it must not contain "chunked" again.
 */
func checkLastTeStr(te: String): Unit {
    let lastPart = Str(te).splitLast(SYMBOL_COMMA)
    if (lastPart.trim() != Str("chunked")) {
        throw HttpException("The last value of transfer-encoding must be chunked.")
    }
    let rest = te[..te.size - lastPart.size]
    checkOtherTeStr(rest)
}

/* Throws HttpException if any comma-separated coding in `te` is "chunked". */
func checkOtherTeStr(te: String): Unit {
    let noChunked = Str(te).splitAllMatch(SYMBOL_COMMA, {part => part.trim() != Str("chunked")})
    if (noChunked) {
        return
    }
    throw HttpException("Chunked should not be set more than once.")
}

/*
 * Returns false for "" and for strings whose first byte sorts below '1' (e.g. "01", "+1"),
 * except that "0" and strings starting with '-' are accepted here (sign handling is
 * left to the caller).
 */
private func hasNoLeadingZeros(s: String): Bool {
    if (s.isEmpty()) {
        return false
    }
    let first = s[0]
    return s == "0" || first == b'-' || first >= b'1'
}

/**
 * Validates the content-length header value(s) and normalizes the header to a single value.
 *
 * All comma-separated values, including those in `hv.extra`, must be identical, and the
 * value must be a plain decimal number (Content-Length = 1*DIGIT): no sign, no leading
 * zeros. "-0" is therefore rejected even though it parses to 0.
 *
 * @param hv the content-length header value(s).
 * @param headers the headers the value came from; "content-length" is reset to the single value.
 * @return the content length.
 * @throws HttpException on any invalid, signed, zero-padded or conflicting value.
 */
func checkCl(hv: HeaderValue, headers: HttpHeaders): Int64 {
    let clStr = checkFirstClStr(hv.single) ?? throw HttpException("The content-length invalid.")
    let clString = clStr.toString()
    if (!hasNoLeadingZeros(clString)) {
        throw HttpException("The content-length has leading zeros.")
    }
    let cl = Int64.tryParse(clString) ?? throw HttpException("The content-length invalid.")
    // clString is non-empty here (hasNoLeadingZeros rejects ""); any '-' is invalid, including "-0"
    if (cl < 0 || clString[0] == b'-') {
        throw HttpException("The content-length should not be negative.")
    }
    for (i in 0..hv.extra.size) {
        checkOtherClStr(hv.extra[i], clStr)
    }
    headers.set("content-length", clString) // a new String
    return cl
}

/* Validates "content-length" when present (see the HeaderValue overload); None when absent. */
func checkCl(headers: HttpHeaders): ?Int64 {
    match (headers.getInternal("content-length")) {
        case Some(hv) => return checkCl(hv, headers)
        case None => return None
    }
}

/*
 * Returns the trimmed first comma-separated element of `cl`, after checking that every
 * following element equals it (checkOtherClStr). Returns None if splitFirst yields nothing.
 */
func checkFirstClStr(cl: String): ?Str {
    let firstPart = Str(cl).splitFirst(SYMBOL_COMMA) ?? return None
    let trimmed = firstPart.trim()
    let restStart = firstPart.size + 1
    if (restStart < cl.size) {
        checkOtherClStr(cl[restStart..], trimmed)
    }
    return trimmed
}

/* Throws HttpException unless every trimmed comma-separated element of `cl` equals `first`. */
func checkOtherClStr(cl: String, first: Str): Unit {
    let allSame = Str(cl).splitAllMatch(SYMBOL_COMMA, {part => part.trim() == first})
    if (!allSame) {
        throw HttpException("The content-length should not has different values.")
    }
}

/**
 * Normalizes the "connection" header of `headers` in place.
 *
 * 1. Every hop-by-hop header name (HopByHopHeaders) that is present in `headers` but not yet
 *    listed in "connection" is appended to "connection".
 * 2. If "connection" then lists none of "close", "keep-alive" or "upgrade", "keep-alive"
 *    is appended.
 */
func checkConnection(headers: HttpHeaders): Unit {
    for (i in 0..HopByHopHeaders.size where !headers.getInternal(HopByHopHeaders[i]).isNone()) {
        if (let Some(hv) <- headers.getInternal("connection")) {
            if (hv.splitAnyMatch(SYMBOL_COMMA, HopByHopHeaders[i])) {
                // Already listed: continue with the next header. A `return` here would skip
                // the remaining hop-by-hop headers and step 2.
                continue
            }
        }
        headers.add("connection", HopByHopHeaders[i])
    }

    // if there is not "close", "keep-alive", "upgrade" in connection values, add "keep-alive" to connection values
    match (headers.getInternal("connection")) {
        case Some(hv) =>
            if (hv.splitAnyMatch(SYMBOL_COMMA, "close") || hv.splitAnyMatch(SYMBOL_COMMA, "keep-alive") || // cjlint-ignore !G.EXP.03
                hv.splitAnyMatch(SYMBOL_COMMA, "upgrade")) { // cjlint-ignore !G.EXP.03
                return
            }
        case _ => ()
    }
    headers.add("connection", "keep-alive")
}

/*
 * Rejects a "trailer" header that announces a field from TrailerExcludeList. The exception
 * type depends on `isReq` (see chooseException).
 */
func checkTrailerInHeader(headers: HttpHeaders, isReq: Bool): Unit {
    match (headers.getInternal("trailer")) {
        case Some(declared) =>
            let allAllowed = declared.splitAllMatch(SYMBOL_COMMA, {name => !TrailerExcludeList.contains(name)})
            if (!allAllowed) {
                throw chooseException("Malformed trailer field in header.", isReq)
            }
        case None => ()
    }
}

/*
 * Keeps only the response trailers announced in its "trailer" header. Without a "trailer"
 * header, all trailers are dropped.
 */
func checkTrailer(response: HttpResponse): Unit {
    if (let Some(trailers) <- response._trailers) {
        match (response.headers.getInternal("trailer")) {
            case Some(declared) => checkTrailer(trailers, declared)
            case None => response._trailers = None
        }
    }
}

/*
 * Keeps only the builder's trailers announced in its "trailer" header. Without a "trailer"
 * header, all trailers are dropped.
 */
func checkTrailer(responseBuilder: HttpResponseBuilder): Unit {
    if (let Some(trailers) <- responseBuilder._trailers) {
        match (responseBuilder.headers.getInternal("trailer")) {
            case Some(declared) => checkTrailer(trailers, declared)
            case None => responseBuilder._trailers = None
        }
    }
}

/*
 * Keeps only the request trailers announced in its "trailer" header. Without a "trailer"
 * header, all trailers are dropped.
 */
func checkTrailer(request: HttpRequest): Unit {
    if (let Some(trailers) <- request._trailers) {
        match (request.headers.getInternal("trailer")) {
            case Some(declared) => checkTrailer(trailers, declared)
            case None => request._trailers = None
        }
    }
}

/* Removes, in place, every trailer field whose name is not listed in the "trailer" header value. */
func checkTrailer(trailer: HttpHeaders, header: HeaderValue) {
    let it = trailer.map.iterator()
    while (let Some((name, _)) <- it.next()) {
        let announced = header.splitAnyMatch(SYMBOL_COMMA, name)
        if (!announced) {
            it.remove()
        }
    }
}

/*
 * Client-side check of an outgoing "expect" header. Only "100-continue" is supported, and
 * only with a non-empty body. On success, request.expectContinuation is set.
 */
func checkExpect(request: HttpRequest): Unit {
    match (request.headers.getInternal("expect")) {
        case Some(expectValue) =>
            let onlyContinue = expectValue.splitAllMatch(SYMBOL_COMMA, {item => item == Str("100-continue")})
            if (!onlyContinue) {
                throw HttpException("The only expectation supported yet is 100-continue.")
            }
            if (request.body is HttpEmptyBody) {
                throw HttpException(
                    "Must not generate a 100-continue expectation in a request that does not include content.")
            }
            request.expectContinuation = true
        case None => ()
    }
}

/*
 * Returns true when the headers carry "expect: 100-continue" and false when there is no
 * "expect" header. Any other expectation throws HttpStatusException (STATUS_EXPECTATION_FAILED).
 */
func expected100Continue(headers: HttpHeaders): Bool {
    match (headers.getInternal("expect")) {
        case Some(hv) =>
            if (hv.splitAllMatch(SYMBOL_COMMA, {v => v == Str("100-continue")})) {
                return true
            }
            throw HttpStatusException(HttpStatusCode.STATUS_EXPECTATION_FAILED, "Invalid value for expect.")
        case None => return false
    }
}

// Client only: fills in a default "user-agent" for the protocol version when none is set.
func checkUserAgent(version: Protocol, headers: HttpHeaders): Unit {
    let missing = headers.getInternal("user-agent").isNone()
    if (missing) {
        let agent = match (version) {
            case HTTP1_1 | HTTP1_0 => "CANGJIEUSERAGENT_1_1"
            case HTTP2_0 => "CANGJIEUSERAGENT_2_0"
            case _ => throw HttpException("Protocol unknown.")
        }
        headers.set("user-agent", agent)
    }
}

/*
 * Splits every value on ',' and returns the ASCII-trimmed, non-empty pieces in order.
 * Example: ["gzip, deflate", " ,br "] -> ["gzip", "deflate", "br"].
 */
func splitValuesByComma(values: Collection<String>): ArrayList<String> {
    let pieces = ArrayList<String>()
    for (value in values) {
        for (part in value.split(",")) {
            let trimmed = part.trimAscii()
            if (!trimmed.isEmpty()) {
                pieces.add(trimmed)
            }
        }
    }
    return pieces
}

/**
 * Static accessors for per-thread state kept in ThreadLocal storage.
 * `connId` holds the id of the connection handled by the current thread (None if unset).
 */
class ThreadContext {
    private static let connIdSlot = ThreadLocal<UInt64>()

    mut static prop connId: ?UInt64 {
        get() {
            return connIdSlot.get()
        }
        set(value) {
            connIdSlot.set(value)
        }
    }
}

/* Returns the smaller of two Int64 values. */
func min(a: Int64, b: Int64): Int64 {
    return if (b <= a) { b } else { a }
}

/* Returns the larger of two UInt32 values. */
func max(a: UInt32, b: UInt32): UInt32 {
    return if (b >= a) { b } else { a }
}

/**
 * Renders raw bytes as a string that is safe to write to a log.
 *
 * Only printable ASCII (0x20..0x7E) is emitted verbatim. Every other byte is replaced by
 * '.', including C0 controls, DEL and all bytes >= 0x80. High bytes are not mapped through
 * Rune(b), because that yields U+0080..U+00FF, whose C1 range (e.g. U+009B CSI) some
 * terminals interpret as control sequences.
 *
 * @param data the bytes to render.
 * @param maxLen maximum number of bytes rendered. When `data` is longer, the suffix
 *        "... (truncated, N bytes total)" is appended.
 * @return the sanitized text.
 */
func sanitizeForLog(data: Array<UInt8>, maxLen!: Int64 = DEFAULT_LOG_LENTH): String {
    let result = StringBuilder()
    let len = min(data.size, maxLen)
    for (i in 0..len) {
        let b = data[i]
        // 0x20 = SPACE is the first printable ASCII byte; 0x7F = DEL follows the last (0x7E '~')
        if (b >= 0x20 && b < 0x7F) {
            result.append(Rune(b))
        } else {
            result.append('.')
        }
    }
    let sanitized = result.toString()
    if (data.size > maxLen) {
        return "${sanitized}... (truncated, ${data.size} bytes total)"
    }
    return sanitized
}