/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
package stdx.net.http
import stdx.encoding.base64.toBase64String
import stdx.encoding.url.{URL, Form}
import std.collection.*
import std.convert.Parsable
import std.fs.*
import std.io.*
import stdx.log.*
import std.process.*
import std.net.*
import std.time.*
import std.sync.*
// Lookup table mapping a nibble value (0..15) to its lower-case hex digit.
const INT64_TO_HEX: String = "0123456789abcdef"
// Mask selecting the lowest 4 bits (one hex digit) of an Int64.
const MASK_I64: Int64 = 0x0f
// Default maximum number of bytes rendered by sanitizeForLog.
// NOTE(review): "LENTH" is a misspelling of "LENGTH"; kept because other files may reference this name.
const DEFAULT_LOG_LENTH: Int64 = 256
// Three-letter day names, indexed by DayOfWeek.toInteger() (index 0 is "Sun"), used by getRFC1123String.
let shortDayOfWeekForm: Array<Str> = [Str("Sun"), Str("Mon"), Str("Tue"), Str("Wed"), Str("Thu"), Str("Fri"), Str("Sat")]
// Three-letter month names, indexed by Month.toInteger() - 1 (index 0 is "Jan"), used by getRFC1123String.
let shortMonthForm: Array<Str> = [Str("Jan"), Str("Feb"), Str("Mar"), Str("Apr"), Str("May"), Str("Jun"), Str("Jul"),
Str("Aug"), Str("Sep"), Str("Oct"), Str("Nov"), Str("Dec")]
/*
 * Interning table of well-known, lower-case header names. Each key and value spell the same
 * name. getString and toLowerCaseStr look names up here and return the stored String, presumably
 * so that common header names share one String instance instead of being re-allocated per request.
 * Entries such as "authority", "method", "path", "scheme" and "status" are presumably the HTTP/2
 * pseudo-header names without their leading ':' — confirm with the HTTP/2 codec.
 */
let stringTable: HashMap<Str, String> = HashMap<Str, String>(
[
(Str("host"), "host"),
(Str("connection"), "connection"),
(Str("upgrade-insecure-requests"), "upgrade-insecure-requests"),
(Str("user-agent"), "user-agent"),
(Str("accept"), "accept"),
(Str("accept-encoding"), "accept-encoding"),
(Str("accept-language"), "accept-language"),
(Str("authority"), "authority"),
(Str("method"), "method"),
(Str("path"), "path"),
(Str("scheme"), "scheme"),
(Str("status"), "status"),
(Str("accept-charset"), "accept-charset"),
(Str("accept-ranges"), "accept-ranges"),
(Str("access-control-allow-origin"), "access-control-allow-origin"),
(Str("age"), "age"),
(Str("authorization"), "authorization"),
(Str("cache-control"), "cache-control"),
(Str("content-disposition"), "content-disposition"),
(Str("content-encoding"), "content-encoding"),
(Str("content-language"), "content-language"),
(Str("content-length"), "content-length"),
(Str("content-location"), "content-location"),
(Str("content-range"), "content-range"),
(Str("content-type"), "content-type"),
(Str("cookie"), "cookie"),
(Str("date"), "date"),
(Str("etag"), "etag"),
(Str("expect"), "expect"),
(Str("expires"), "expires"),
(Str("from"), "from"),
(Str("if-match"), "if-match"),
(Str("if-modified-since"), "if-modified-since"),
(Str("if-none-match"), "if-none-match"),
(Str("if-range"), "if-range"),
(Str("if-unmodified-since"), "if-unmodified-since"),
(Str("last-modified"), "last-modified"),
(Str("link"), "link"),
(Str("location"), "location"),
(Str("max-forwards"), "max-forwards"),
(Str("proxy-authenticate"), "proxy-authenticate"),
(Str("proxy-authorization"), "proxy-authorization"),
(Str("range"), "range"),
(Str("referer"), "referer"),
(Str("refresh"), "refresh"),
(Str("retry-after"), "retry-after"),
(Str("server"), "server"),
(Str("set-cookie"), "set-cookie"),
(Str("strict-transport-security"), "strict-transport-security"),
(Str("transfer-encoding"), "transfer-encoding"),
(Str("vary"), "vary"),
(Str("via"), "via"),
(Str("www-authenticate"), "www-authenticate")
]
)
/**
 * Converts a header-name view to a String.
 * Returns the shared String stored in stringTable when the name is known;
 * otherwise falls back to converting `str` directly.
 */
func getString(str: Str): String {
    match (stringTable.get(str)) {
        case Some(interned) => interned
        case None => str.toString()
    }
}
/**
 * Produces a lower-case Str for a header name.
 * If `str` has no upper-case characters, it is wrapped unchanged. Otherwise the stored spelling
 * from stringTable is used when the lookup hits; failing that, an ASCII-lower-cased copy is used.
 * NOTE(review): stringTable keys are lower-case, so the lookup only hits for mixed-case input if
 * Str equality/hashing is case-insensitive — confirm against the Str definition.
 */
func toLowerCaseStr(str: String): Str {
    // Fast path: nothing to lower-case.
    if (!str.hasUpper()) {
        return Str(str)
    }
    return match (stringTable.get(Str(str))) {
        case Some(interned) => Str(interned)
        case None => Str(str.toAsciiLower())
    }
}
extend Int64 {
    /**
     * Formats this value as lower-case hexadecimal digits, without a "0x" prefix or leading zeros.
     * Zero is rendered as "0".
     * Negative values are not supported and produce an empty string.
     */
    func toHexString(): String {
        if (this == 0) {
            return "0"
        }
        if (this < 0) {
            return ""
        }
        // A positive Int64 has at most 16 hex digits; fill the buffer right-to-left.
        let digits = Array<Byte>(16, repeat: b'0')
        var num = this
        var pos = 16
        while (num > 0) {
            pos--
            digits[pos] = INT64_TO_HEX[num & MASK_I64]
            num = num >> 4
        }
        return String.fromUtf8(digits[pos..])
    }

    /**
     * Parses a non-empty run of hex digits (either case, no prefix, no sign) into an Int64.
     * Throws IllegalArgumentException when the input is empty or contains a non-hex byte.
     * Also throws when the value cannot fit in a non-negative Int64: more than 16 digits,
     * or 16 digits whose first digit is above 7.
     */
    static func fromHexStr(hexRaw: Str): Int64 {
        if (hexRaw.isEmpty()) {
            throw IllegalArgumentException("Invalid hex string.")
        }
        // Reject inputs that would overflow before accumulating any digits.
        if (hexRaw.size > 16 || (hexRaw.size == 16 &&
            (fromHexByte(hexRaw[0]) ?? throw IllegalArgumentException("Invalid hex string.")) > 7)) { // cjlint-ignore !G.EXP.03
            throw IllegalArgumentException("The hex number is too big.")
        }
        var number: Int64 = 0
        for (i in 0..hexRaw.size) {
            match (fromHexByte(hexRaw[i])) {
                case Some(v) =>
                    number <<= 4
                    number |= v
                case None => throw IllegalArgumentException("Invalid hex string.")
            }
        }
        return number
    }

    /** Returns the value (0-15) of a single ASCII hex digit, or None for any other byte. */
    static func fromHexByte(h: Byte): ?Int64 {
        return match {
            case b'0' <= h && h <= b'9' => Int64(h - b'0')
            case b'a' <= h && h <= b'f' => Int64(h - b'a' + 10)
            case b'A' <= h && h <= b'F' => Int64(h - b'A' + 10)
            case _ => None
        }
    }
}
/* A 3xx response status other than 304 Not Modified means the client must follow a redirect. */
let needRedirect = {code: UInt16 => code / 100 == 3 && code != HttpStatusCode.STATUS_NOT_MODIFIED}
/*
 * Normalizes a request path. Empty segments ("//") are collapsed and "." segments are dropped.
 * A ".." segment removes the preceding segment and never climbs above the root.
 * The result always starts with "/"; a relative input gets a leading slash.
 * A trailing slash is kept, and a trailing "." or ".." segment leaves the result ending in "/"
 * (e.g. "/a/b/.." -> "/a/", "/.." -> "/").
 */
func canonicalPath(path: String): String {
    if (path.isEmpty() || path == "/") {
        return path
    }
    let raw = unsafe { path.rawData() }
    // The output can never be longer than the input plus a leading slash; newPath[0] stays '/'.
    let newPath = Array<Byte>(raw.size + 1, repeat: b'/')
    var partLeft = 0
    var partRight = 0
    var writeIdx = 1
    while (partRight < raw.size) {
        let b = raw[partRight]
        if (b != b'/') {
            partRight++
            continue
        }
        match (path[partLeft..partRight]) {
            case "" | "." => ()
            case ".." => writeIdx = findPrevSlash(newPath, writeIdx) + 1
            case _ =>
                raw.copyTo(newPath, partLeft, writeIdx, partRight - partLeft)
                writeIdx += partRight - partLeft
                newPath[writeIdx] = b'/'
                writeIdx++
        }
        partRight++
        partLeft = partRight
    }
    // The final segment has no terminating slash. It must still be resolved;
    // otherwise a trailing "." or ".." (e.g. "/..") would survive verbatim.
    if (partLeft < raw.size) {
        match (path[partLeft..]) {
            case "." => ()
            case ".." => writeIdx = findPrevSlash(newPath, writeIdx) + 1
            case _ =>
                raw.copyTo(newPath, partLeft, writeIdx, raw.size - partLeft)
                writeIdx += raw.size - partLeft
        }
    }
    return unsafe { String.fromUtf8Unchecked(newPath[..writeIdx]) }
}

/*
 * Input: `curIdx` is the write position just past a '/' in `path`.
 * Returns the index of the '/' that precedes the last written segment,
 * or 0 when only the leading slash has been written.
 * Relies on path[0] being '/', so the backwards scan always stops at index >= 0.
 */
func findPrevSlash(path: Array<Byte>, curIdx: Int64): Int64 {
    // only start slash in path
    if (curIdx == 1) {
        return 0
    }
    // Skip the slash just before curIdx, then walk back to the previous one.
    var prev = curIdx - 2
    while (prev >= 0) {
        if (path[prev] == b'/') {
            break
        }
        prev--
    }
    return prev
}
extend String {
    /**
     * Returns true when at least one element of `values`, lower-cased with toAsciiLower(),
     * equals this string. The comparison is only case-insensitive if this string is already lower-case.
     */
    func caseInsensitiveMatchOne(values: ArrayList<String>): Bool {
        for (candidate in values where candidate.toAsciiLower() == this) {
            return true
        }
        return false
    }

    /**
     * Returns true when every element of `values`, lower-cased with toAsciiLower(),
     * equals this string. An empty list yields true.
     */
    func caseInsensitiveMatchAll(values: ArrayList<String>): Bool {
        for (candidate in values where candidate.toAsciiLower() != this) {
            return false
        }
        return true
    }

    /** Returns true when every element of `values` equals this string exactly; an empty list yields true. */
    func matchAll(values: ArrayList<String>): Bool {
        for (candidate in values where candidate != this) {
            return false
        }
        return true
    }

    /** Returns `s` when this string is empty, otherwise this string itself. */
    func ifEmpty(s: String): String {
        if (isEmpty()) {
            return s
        }
        return this
    }

    /** Reports whether this string contains an upper-case character, as decided by Str.hasUpper(). */
    func hasUpper(): Bool {
        let view = Str(this)
        return view.hasUpper()
    }
}
extend<K, V> HashMapIterator<K, V> {
    /**
     * Removes the entry most recently returned by next() when `condition` evaluates to true.
     * Otherwise the map is left untouched.
     */
    func removeIf(condition: () -> Bool): Unit {
        let shouldRemove = condition()
        if (shouldRemove) {
            this.remove()
        }
    }
}
extend<K, V> HashMap<K, V> {
    /** Removes `key` from the map only when `condition` evaluates to true. */
    func removeKeyIf(key: K, condition: () -> Bool): Unit {
        let shouldRemove = condition()
        if (shouldRemove) {
            this.remove(key)
        }
    }
}
/*
 * Maps a file extension to the Content-Type value to use for it.
 * Keys are lower-case and include the leading dot. All text/* types carry an explicit
 * "charset=utf-8" parameter.
 * NOTE(review): callers presumably normalize the extension's case before lookup — confirm.
 */
let EXT_TYPE_MAP = HashMap<String, String>(
[
(".avif", "image/avif"),
(".css", "text/css; charset=utf-8"),
(".gif", "image/gif"),
(".htm", "text/html; charset=utf-8"),
(".html", "text/html; charset=utf-8"),
(".jpeg", "image/jpeg"),
(".jpg", "image/jpeg"),
(".js", "text/javascript; charset=utf-8"),
(".json", "application/json"),
(".mjs", "text/javascript; charset=utf-8"),
(".pdf", "application/pdf"),
(".png", "image/png"),
(".svg", "image/svg+xml"),
(".wasm", "application/wasm"),
(".webp", "image/webp"),
(".xml", "text/xml; charset=utf-8")
]
)
/* Clamps a duration so it is never negative: any value below zero becomes Duration.Zero. */
func checkDuration(d: Duration): Duration {
    return if (d >= Duration.Zero) {
        d
    } else {
        Duration.Zero
    }
}
/* Returns the Base64 encoding of "username:password", the credential token used after "Basic ". */
func basicAuth(username: String, password: String): String {
    let credentials = "${username}:${password}"
    return toBase64String(credentials.toArray())
}
/*
 * Builds a "Basic <token>" authorization value from the user info embedded in a proxy URL.
 * Returns None when the URL carries no user name; a missing password is treated as "".
 * NOTE(review): rawUserInfo is presumably the still percent-encoded form — confirm whether the
 * decoded user info should be used instead.
 */
func parseProxyAuth(url: URL): ?String {
    let user = url.rawUserInfo.username()
    if (user.isEmpty()) {
        return None
    }
    let pass = url.rawUserInfo.password() ?? ""
    return Some("Basic ${basicAuth(user, pass)}")
}
/*
 * Validates a multipart boundary against the character rules of RFC 2046 section 5.1.1.
 * The boundary must be 1 to 70 characters long, must not end with a space, and may only contain
 * ASCII letters, digits and the characters ' ( ) + _ , - . / : = ?
 * NOTE(review): RFC 2046 `bchars` also allows a space anywhere except the last position. This check
 * rejects every space, which makes the trailing-space test redundant. Confirm that boundaries are
 * always quoted before relaxing it.
 */
func checkBoundary(boundary: String): Bool {
    if (boundary.size < 1 || boundary.size > 70 || boundary.endsWith(" ")) {
        return false
    }
    for (ch in boundary) {
        let permitted = match (ch) {
            case '\'' | '(' | ')' | '+' | '_' | ',' | '-' | '.' | '/' | ':' | '=' | '?' => true
            case _ => ch.isAsciiNumberOrLetter()
        }
        if (!permitted) {
            return false
        }
    }
    return true
}
/*
 * Returns the length reported by a Seekable stream, or None when the stream is not Seekable
 * or reports a negative length.
 */
func sizeOf(inputStream: InputStream): ?Int64 {
    if (let Some(seekable) <- (inputStream as Seekable)) {
        let reported = seekable.length
        if (reported >= 0) {
            return reported
        }
    }
    return None
}
/*
 * Thin holder for an optional one-shot Timer. It exists to:
 * 1. skip creating and running a Timer when `start` is Duration.Max;
 * 2. spare each class (and future extensions) from initializing a Timer through its own
 *    zero value or Option.
 */
class HttpTimer {
    var timer: ?Timer = None
    // Shared instance that never fires (no Timer is ever created for it).
    static let empty: HttpTimer = HttpTimer()

    private init() {}

    init(start!: Duration, task!: () -> Unit) {
        timer = if (start == Duration.Max) {
            None<Timer>
        } else {
            Some(Timer.once(start, task))
        }
    }

    /* Cancels the pending timer, if any, and forgets it. */
    func cancel(): Unit {
        match (timer) {
            case Some(active) =>
                active.cancel()
                timer = None
            case None => ()
        }
    }
}
/*
 * Writes the current UTC time into the caller-provided `dateArr` and returns a Str over the slice
 * that starts at the three-letter day-of-week name.
 * Fixed positions written here: seconds at 41-42, minutes at 38-39, hours at 35-36. The year is
 * written right-to-left ending at index 33 (see writeYear). Day-of-month, month and day-of-week
 * are placed at offsets relative to the index just before the first year digit.
 * NOTE(review): the separators between the time fields, the byte after the year and any trailing
 * zone suffix are not written here. This assumes `dateArr` is a pre-filled template — confirm with
 * the caller.
 */
func getRFC1123String(dateArr: Array<Byte>): Str {
let t = DateTime.nowUTC()
writeTwoNum(dateArr, t.second, 41)
writeTwoNum(dateArr, t.minute, 38)
writeTwoNum(dateArr, t.hour, 35)
// i is the index just before the first year digit
let i = writeYear(dateArr, t.year)
dateArr[i] = b' '
// month name occupies i-3..i-1 (Month.toInteger() is 1-based)
shortMonthForm[t.month.toInteger() - 1].raw.copyTo(dateArr, 0, i - 3, 3)
dateArr[i - 4] = b' '
// two-digit day of month at i-6..i-5
writeTwoNum(dateArr, t.dayOfMonth, i - 6)
dateArr[i - 7] = b' '
dateArr[i - 8] = b','
// day-of-week name at i-11..i-9; the returned slice starts here
shortDayOfWeekForm[t.dayOfWeek.toInteger()].raw.copyTo(dateArr, 0, i - 11, 3)
return Str(dateArr[i - 11..])
}
/*
 * Writes `n` as exactly two ASCII digits at dateArr[idx] and dateArr[idx + 1].
 * Throws IllegalArgumentException when `n` is outside 0..99.
 */
func writeTwoNum(dateArr: Array<Byte>, n: Int64, idx: Int64): Unit {
    if (!(0 <= n && n <= 99)) {
        throw IllegalArgumentException("Invalid date number: ${n}.")
    }
    let tens = UInt8(n / 10)
    let ones = UInt8(n % 10)
    dateArr[idx] = b'0' + tens
    dateArr[idx + 1] = b'0' + ones
}
/*
 * Writes the decimal digits of `year` right-to-left so that the last digit lands at index 33.
 * Returns the index just before the first digit written (33 when `year` is 0, since nothing is written).
 * A negative year cannot be converted to UInt64; the conversion presumably throws.
 */
func writeYear(dateArr: Array<Byte>, year: Int64): Int64 {
    var remaining = UInt64(year)
    var pos = 33
    while (remaining != 0) {
        dateArr[pos] = b'0' + UInt8(remaining % 10)
        remaining = remaining / 10
        pos -= 1
    }
    return pos
}
/*
 * Parses an HTTP/1.x status line and returns (version, status code).
 * The version must be "HTTP/1.0" or "HTTP/1.1". The status code must be exactly three characters
 * (RFC 9112: status-code = 3DIGIT) and lie within 100..599.
 * Throws HttpException on any violation.
 */
func parseAndCheckResponseLine(line: Str): (String, UInt16) {
    let version = line.splitFirst(WS)?.toString() ?? throw HttpException("Invalid response status line.")
    if (version != "HTTP/1.0" && version != "HTTP/1.1") {
        throw HttpException("Invalid response status line.")
    }
    /**
     * The status code start from index 9.
     *
     * 0123456789abcde
     * ---------------
     * HTTP/1.1 200 OK
     */
    let status = line.splitFirst(WS, 9)?.toString() ?? throw HttpException("Invalid response status line.")
    // A lenient integer parse may accept forms such as "0200"; the grammar requires exactly 3 digits.
    if (status.size != 3) {
        throw HttpException("Invalid response status code.")
    }
    let statusCode = match (UInt16.tryParse(status)) {
        case Some(v) =>
            if (v < 100 || v > 599) {
                throw HttpException("Invalid response status code.")
            }
            v
        case None => throw HttpException("Invalid response status line.")
    }
    return (version, statusCode)
}
/**
 * Parses an HTTP/1.x request line "METHOD SP request-target SP HTTP-version" and returns
 * (method, target, version) as Strings.
 *
 * Throws HttpStatusException with:
 *   - STATUS_BAD_REQUEST when a component cannot be split out, when the method is longer than
 *     MAX_METHOD_SIZE, or when anything follows the version;
 *   - STATUS_METHOD_NOT_ALLOWED when the method is empty or contains non-token bytes;
 *   - STATUS_HTTP_VERSION_NOT_SUPPORTED for versions other than HTTP/1.0 and HTTP/1.1.
 * NOTE(review): splitting relies on Str.splitFirst(sep, start), defined elsewhere. It presumably
 * returns the bytes from `start` up to the next WS (or the end of the line).
 */
func parseAndCheckRequestLine(line: Str): (String, String, String) {
// parse and check method
let method = line.splitFirst(WS) ?? throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, // cjlint-ignore !G.EXP.03
"Invalid request method, request line: ${line}.")
if (method.isEmpty() || !method.byteMatches(isTokenByte)) { // cjlint-ignore !G.EXP.03
throw HttpStatusException(HttpStatusCode.STATUS_METHOD_NOT_ALLOWED, "Invalid request method: ${method}.")
}
if (method.size > MAX_METHOD_SIZE) {
throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, "Invalid method: ${method}.")
}
// parse target
// i is the byte offset just past "METHOD "
var i = method.size + 1
let target = line.splitFirst(WS, i) ?? throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, // cjlint-ignore !G.EXP.03
"Invalid request target, request line: ${line}.")
// parse and check version
// advance i past "target "
i += target.size + 1
let version = line.splitFirst(WS, i) ?? throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, // cjlint-ignore !G.EXP.03
"Invalid request version, request line: ${line}.")
if (version != Str("HTTP/1.1") && version != Str("HTTP/1.0")) { // cjlint-ignore !G.EXP.03
throw HttpStatusException(HttpStatusCode.STATUS_HTTP_VERSION_NOT_SUPPORTED, "Incorrect version: ${version}.")
}
// the version must end the line exactly; anything after it is an extra element
if (i + version.size != line.size) {
throw HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST,
"Invalid request line, more than three elements: ${line}.")
}
return (method.toString(), target.toString(), version.toString())
}
/**
 * Parses one header field line "name: value" and validates it.
 * Returns (name, value). The name is passed through getString; the value is trimmed with Str.trim().
 * Throws the exception built by chooseException (HttpStatusException 400 for requests,
 * HttpException otherwise) when:
 *   - the name is empty or not a token;
 *   - the name spans the whole line (no colon);
 *   - the value contains bytes rejected by checkValueBytes.
 */
func parseAndCheckHeaderLine(line: Str, isReq: Bool): (String, String) {
    // parse and check name: field-name = token, which requires at least one tchar (RFC 9110)
    let name = line.splitFirst(SYMBOL_COLON) ?? throw chooseException("Invalid field line: ${line}.", isReq) // cjlint-ignore !G.EXP.03
    if (name.isEmpty() || !name.byteMatches(isTokenByte)) { // cjlint-ignore !G.EXP.03
        throw chooseException("Invalid field line: ${line}.", isReq)
    }
    // The name covering the whole line means no colon was found.
    if (name.size == line.size) {
        throw chooseException("Invalid field line: ${line}.", isReq)
    }
    // parse and check value
    let value = line.slice(name.size + 1).trim()
    if (!value.byteMatches(checkValueBytes)) {
        throw chooseException("Invalid field line: ${line}.", isReq)
    }
    return (getString(name), value.toString())
}
/*
 * Returns true when `b` may appear in a header field value.
 * Control bytes below 0x20 (except HTAB 0x09) and DEL (0x7F) are rejected; bytes 0x80-0xFF are
 * allowed. A Byte can never exceed 0xFF, so no upper-bound check is needed.
 */
func checkValueBytes(b: Byte): Bool {
    let isForbiddenControl = b < 0x20 && b != 0x09
    return !(isForbiddenControl || b == 0x7f)
}
/*
 * Builds the parsing error for the current side of the connection:
 * a 400 HttpStatusException when parsing a request, a plain HttpException otherwise.
 */
func chooseException(msg: String, isReq: Bool): Exception {
    if (!isReq) {
        return HttpException(msg)
    }
    return HttpStatusException(HttpStatusCode.STATUS_BAD_REQUEST, msg)
}
/*
 * Validates the message-framing headers. Returns (content-length if present, whether
 * transfer-encoding is present).
 * Throws HttpException when both headers are set or when either value is malformed.
 */
func checkClAndTe(headers: HttpHeaders): (?Int64, Bool) {
    let chunked = match (headers.getInternal("transfer-encoding")) {
        case Some(te) =>
            checkTe(te)
            true
        case None => false
    }
    let contentLength: ?Int64 = match (headers.getInternal("content-length")) {
        case Some(cl) =>
            if (chunked) {
                throw HttpException("The content-length and transfer-encoding can not be set together.")
            }
            Some(checkCl(cl, headers))
        case None => None
    }
    return (contentLength, chunked)
}
/*
 * Validates all transfer-encoding field lines. The last value must end with "chunked", and no
 * other value may contain "chunked".
 * NOTE(review): presumably `single` holds the first line and `extra` the subsequent ones.
 */
func checkTe(hv: HeaderValue): Unit {
    if (hv.isSingle()) {
        checkLastTeStr(hv.single)
        return
    }
    let lastIdx = hv.extra.size - 1
    checkLastTeStr(hv.extra[lastIdx])
    for (i in 0..lastIdx) {
        checkOtherTeStr(hv.extra[i])
    }
    checkOtherTeStr(hv.single)
}
/*
 * Validates the final transfer-encoding value. Its last comma-separated coding must be "chunked",
 * and the codings before it must not contain "chunked".
 */
func checkLastTeStr(te: String): Unit {
    // presumably the text after the final comma (or the whole value when there is no comma)
    let lastCoding = Str(te).splitLast(SYMBOL_COMMA)
    if (lastCoding.trim() != Str("chunked")) {
        throw HttpException("The last value of transfer-encoding must be chunked.")
    }
    let headEnd = te.size - lastCoding.size
    checkOtherTeStr(te[..headEnd])
}
/* Throws when any comma-separated coding in `te` is "chunked" (after trimming). */
func checkOtherTeStr(te: String): Unit {
    let noChunked = Str(te).splitAllMatch(SYMBOL_COMMA, {coding => coding.trim() != Str("chunked")})
    if (noChunked) {
        return
    }
    throw HttpException("Chunked should not be set more than once.")
}
/*
 * Returns false for an empty string or one whose first byte sorts below '1'
 * (which includes a leading '0'). The exceptions are "0" itself and strings starting with '-',
 * which are left for later sign checking.
 */
private func hasNoLeadingZeros(s: String): Bool {
    match {
        case s.isEmpty() => false
        case s == "0" || s[0] == b'-' => true
        case _ => s[0] >= b'1'
    }
}
/*
 * Validates every content-length field line and returns the agreed length.
 * All comma-separated values must match the first one, which must be a non-negative integer
 * without leading zeros. Afterwards the header is replaced by the single normalized value.
 */
func checkCl(hv: HeaderValue, headers: HttpHeaders): Int64 {
    let first = checkFirstClStr(hv.single) ?? throw HttpException("The content-length invalid.")
    let text = first.toString()
    if (!hasNoLeadingZeros(text)) {
        throw HttpException("The content-length has leading zeros.")
    }
    let length = Int64.tryParse(text) ?? throw HttpException("The content-length invalid.")
    if (length < 0) {
        throw HttpException("The content-length should not be negative.")
    }
    for (i in 0..hv.extra.size) {
        checkOtherClStr(hv.extra[i], first)
    }
    headers.set("content-length", text) // a new String, independent of the parsed Str
    return length
}
/* Validates the content-length header when present; returns None when it is absent. */
func checkCl(headers: HttpHeaders): ?Int64 {
    match (headers.getInternal("content-length")) {
        case Some(hv) => Some(checkCl(hv, headers))
        case None => None
    }
}
/*
 * Returns the trimmed first comma-separated value of a content-length line.
 * Throws (via checkOtherClStr) when the remaining values differ from it, and returns None when
 * splitFirst yields nothing.
 */
func checkFirstClStr(cl: String): ?Str {
    let head = Str(cl).splitFirst(SYMBOL_COMMA) ?? return None
    let value = head.trim()
    let restStart = head.size + 1
    if (restStart < cl.size) {
        checkOtherClStr(cl[restStart..], value)
    }
    return value
}
/* Throws unless every comma-separated value in `cl` equals `first` after trimming. */
func checkOtherClStr(cl: String, first: Str): Unit {
    let allSame = Str(cl).splitAllMatch(SYMBOL_COMMA, {candidate => candidate.trim() == first})
    if (!allSame) {
        throw HttpException("The content-length should not has different values.")
    }
}
/**
 * Normalizes the "connection" header:
 * 1. every header from HopByHopHeaders that is present in `headers` but not yet listed in
 *    "connection" is appended to it;
 * 2. if "connection" then lists none of "close", "keep-alive" or "upgrade", "keep-alive" is added.
 */
func checkConnection(headers: HttpHeaders): Unit {
    for (i in 0..HopByHopHeaders.size where !headers.getInternal(HopByHopHeaders[i]).isNone()) {
        let name = HopByHopHeaders[i]
        if (let Some(hv) <- headers.getInternal("connection")) {
            if (hv.splitAnyMatch(SYMBOL_COMMA, name)) {
                // Already listed. Check the next hop-by-hop header; returning here would skip the
                // remaining headers and the keep-alive default below.
                continue
            }
        }
        headers.add("connection", name)
    }
    // if there is not "close", "keep-alive", "upgrade" in connection values, add "keep-alive" to connection values
    match (headers.getInternal("connection")) {
        case Some(hv) =>
            if (hv.splitAnyMatch(SYMBOL_COMMA, "close") || hv.splitAnyMatch(SYMBOL_COMMA, "keep-alive") || // cjlint-ignore !G.EXP.03
                hv.splitAnyMatch(SYMBOL_COMMA, "upgrade")) { // cjlint-ignore !G.EXP.03
                return
            }
        case _ => ()
    }
    headers.add("connection", "keep-alive")
}
/* Throws when the "trailer" header declares a field name that TrailerExcludeList forbids. */
func checkTrailerInHeader(headers: HttpHeaders, isReq: Bool): Unit {
    let hv = headers.getInternal("trailer") ?? return
    let allAllowed = hv.splitAllMatch(SYMBOL_COMMA, {declared => !TrailerExcludeList.contains(declared)})
    if (!allAllowed) {
        throw chooseException("Malformed trailer field in header.", isReq)
    }
}
/*
 * Keeps only the response trailers declared in the "trailer" header.
 * Drops all trailers when that header is absent.
 */
func checkTrailer(response: HttpResponse): Unit {
    let trailers = response._trailers ?? return
    match (response.headers.getInternal("trailer")) {
        case Some(declared) => checkTrailer(trailers, declared)
        case None => response._trailers = None
    }
}
/*
 * Keeps only the builder's trailers declared in its "trailer" header.
 * Drops all trailers when that header is absent.
 */
func checkTrailer(responseBuilder: HttpResponseBuilder): Unit {
    let trailers = responseBuilder._trailers ?? return
    match (responseBuilder.headers.getInternal("trailer")) {
        case Some(declared) => checkTrailer(trailers, declared)
        case None => responseBuilder._trailers = None
    }
}
/*
 * Keeps only the request trailers declared in the "trailer" header.
 * Drops all trailers when that header is absent.
 */
func checkTrailer(request: HttpRequest): Unit {
    let trailers = request._trailers ?? return
    match (request.headers.getInternal("trailer")) {
        case Some(declared) => checkTrailer(trailers, declared)
        case None => request._trailers = None
    }
}
/* Removes, in place, every entry of `trailer` whose name is not listed in the `header` value. */
func checkTrailer(trailer: HttpHeaders, header: HeaderValue) {
    let it = trailer.map.iterator()
    var entry = it.next()
    while (let Some((fieldName, _)) <- entry) {
        if (!header.splitAnyMatch(SYMBOL_COMMA, fieldName)) {
            it.remove()
        }
        entry = it.next()
    }
}
/*
 * Client-side validation of the "expect" header. Only "100-continue" is supported, and it
 * requires a non-empty body. On success the request is marked as expecting a continuation.
 */
func checkExpect(request: HttpRequest): Unit {
    let hv = request.headers.getInternal("expect") ?? return
    let onlyContinue = hv.splitAllMatch(SYMBOL_COMMA, {expectation => expectation == Str("100-continue")})
    if (!onlyContinue) {
        throw HttpException("The only expectation supported yet is 100-continue.")
    }
    if (request.body is HttpEmptyBody) {
        throw HttpException(
            "Must not generate a 100-continue expectation in a request that does not include content.")
    }
    request.expectContinuation = true
}
/*
 * Server-side check: returns false when no "expect" header is present and true when it holds
 * only "100-continue". Otherwise throws a 417 HttpStatusException.
 */
func expected100Continue(headers: HttpHeaders): Bool {
    match (headers.getInternal("expect")) {
        case None => false
        case Some(hv) =>
            if (hv.splitAllMatch(SYMBOL_COMMA, {expectation => expectation == Str("100-continue")})) {
                true
            } else {
                throw HttpStatusException(HttpStatusCode.STATUS_EXPECTATION_FAILED, "Invalid value for expect.")
            }
    }
}
// Client only: fills in a default user-agent for the protocol version when none is set.
func checkUserAgent(version: Protocol, headers: HttpHeaders): Unit {
    if (headers.getInternal("user-agent").isSome()) {
        return
    }
    let agent = match (version) {
        case HTTP1_1 | HTTP1_0 => "CANGJIEUSERAGENT_1_1"
        case HTTP2_0 => "CANGJIEUSERAGENT_2_0"
        case _ => throw HttpException("Protocol unknown.")
    }
    headers.set("user-agent", agent)
}
/*
 * Splits every value on ',' and ASCII-trims each piece, preserving order.
 * Returns the non-empty pieces.
 */
func splitValuesByComma(values: Collection<String>): ArrayList<String> {
    let parts = ArrayList<String>()
    for (line in values) {
        for (piece in line.split(",")) {
            let trimmed = piece.trimAscii()
            if (!trimmed.isEmpty()) {
                parts.add(trimmed)
            }
        }
    }
    return parts
}
/**
 * Static accessors for per-thread context values, each backed by a ThreadLocal slot.
 */
class ThreadContext {
    private static let _connId = ThreadLocal<UInt64>()

    /** Connection id bound to the current thread; None when unset. Assigning None clears it. */
    mut static prop connId: ?UInt64 {
        get() {
            _connId.get()
        }
        set(value) {
            _connId.set(value)
        }
    }
}
/* Returns the smaller of two Int64 values. */
func min(a: Int64, b: Int64): Int64 {
    return if (b < a) {
        b
    } else {
        a
    }
}
/* Returns the larger of two UInt32 values. */
func max(a: UInt32, b: UInt32): UInt32 {
    return if (b > a) {
        b
    } else {
        a
    }
}
/**
 * Renders raw bytes as a log-safe string.
 *
 * At most `maxLen` bytes are rendered. A byte is emitted as the Rune with the same code point
 * when it is printable ASCII (0x20-0x7E) or in 0xA0-0xFF. C0 controls (0x00-0x1F), DEL (0x7F) and
 * C1 controls (0x80-0x9F) are replaced with '.', so line breaks and 8-bit control sequences such
 * as CSI (0x9B) cannot reach the log. When `data` is longer than `maxLen`, a suffix reporting the
 * total byte count is appended.
 */
func sanitizeForLog(data: Array<UInt8>, maxLen!: Int64 = DEFAULT_LOG_LENTH): String {
    let result = StringBuilder()
    let len = min(data.size, maxLen)
    for (i in 0..len) {
        let b = data[i]
        // 0x20=SPACE, 0x7F=DEL, 0x80-0x9F=C1 control range
        if (b >= 0x20 && b != 0x7F && (b < 0x80 || b > 0x9F)) {
            result.append(Rune(b))
        } else {
            result.append('.')
        }
    }
    let sanitized = result.toString()
    if (data.size > maxLen) {
        return "${sanitized}... (truncated, ${data.size} bytes total)"
    }
    return sanitized
}