/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.net.http

import stdx.crypto.common.*

// Largest value of a 16-bit unsigned integer. Payload lengths up to this
// bound are written with the 16-bit extended length field.
const MAXUINT16 = 0xFFFF

// First header byte: FIN bit, RSV1-3 bits and the 4-bit opcode mask.
const FIN: UInt8 = 0x80 // 1000 0000
const RSV1: UInt8 = 0x40 // 0100 0000
const RSV2: UInt8 = 0x20 // 0010 0000
const RSV3: UInt8 = 0x10 // 0001 0000
const OPCODE: UInt8 = 0x0F // 0000 1111

// Opcode values (RFC 6455 5.2).
const CONTINUATIONCODE: UInt8 = 0x0
const TEXTCODE: UInt8 = 0x1
const BINARYCODE: UInt8 = 0x2
const CLOSECODE: UInt8 = 0x8
const PINGCODE: UInt8 = 0x9
const PONGCODE: UInt8 = 0xA

// Second header byte: 7-bit payload length markers and the MASK bit.
// These determine how long the header is before the payload starts.
const PAYLOADLEN126: UInt8 = 0x7E
const PAYLOADLEN127: UInt8 = 0x7F
const MASK: UInt8 = 0x80

/**
 * A single WebSocket frame, following the base framing layout of RFC 6455 5.2:
 *
 *   ws-frame = frame-fin             ; 1 bit
 *              frame-rsv1            ; 1 bit
 *              frame-rsv2            ; 1 bit
 *              frame-rsv3            ; 1 bit
 *              frame-opcode          ; 4 bits
 *              frame-masked          ; 1 bit
 *              frame-payload-length  ; 7, 7+16 or 7+64 bits
 *              [ frame-masking-key ] ; 32 bits
 *              frame-payload-data    ; n*8 bits, n >= 0
 *
 * Only `fin`, `frameType` and `payload` are public. The remaining header
 * fields are package-internal mutable state.
 */
public class WebSocketFrame {
    // FIN bit and frame type, fixed at construction
    let _fin: Bool
    let _frameType: WebSocketFrameType

    // payload bytes; None is exposed as an empty array through `payload`
    var _payload: ?Array<UInt8> = None

    // RSV1..RSV3 extension bits
    var rsv1: Bool = false
    var rsv2: Bool = false
    var rsv3: Bool = false

    // MASK bit: whether a masking key accompanies the payload
    var mask: Bool = false

    // raw 7-bit "Payload len" field of the second header byte
    var payloadLen: UInt8 = 0

    // frame-payload-length
    var payloadLength: Int64 = 0

    // 32-bit frame-masking-key, one byte per element
    let maskingKey: Array<UInt8> = Array<UInt8>(4, repeat: 0)

    init(fin: Bool, frameType: WebSocketFrameType) {
        _fin = fin
        _frameType = frameType
    }

    /**
     * Whether this is the final fragment of a message. Together with the
     * frame type, it tells whether the frame is a fragment of a message.
     */
    public prop fin: Bool {
        get() { _fin }
    }

    /**
     * The kind of this frame (continuation, text, binary, close, ping, pong
     * or unknown).
     */
    public prop frameType: WebSocketFrameType {
        get() { _frameType }
    }

    /**
     * The payload data of the frame, or an empty array if none has been set.
     */
    public prop payload: Array<UInt8> {
        get() { _payload ?? Array<UInt8>() }
    }
}

/**
 * Decodes the two fixed header bytes of an incoming frame (RFC 6455 5.2).
 *
 * From `bytes[0]` it takes FIN, RSV1-3 and the opcode. From `bytes[1]` it
 * takes the MASK bit and the raw 7-bit payload length. Extended length,
 * masking key and payload are not read here. Unrecognised (reserved)
 * opcodes map to UnknownWebFrame.
 * `bytes` must contain at least two elements.
 */
func toWebSocketFrameFromFirstTwoBytes(bytes: Array<UInt8>): WebSocketFrame {
    let head = bytes[0]
    let second = bytes[1]

    // Low nibble of the first byte is the opcode.
    // %x3-7 and %xB-F are reserved and fall through to UnknownWebFrame.
    let frameType = match (head & OPCODE) {
        case 0x0 => ContinuationWebFrame
        case 0x1 => TextWebFrame
        case 0x2 => BinaryWebFrame
        case 0x8 => CloseWebFrame
        case 0x9 => PingWebFrame
        case 0xA => PongWebFrame
        case _ => UnknownWebFrame
    }
    let frame = WebSocketFrame((head & FIN) == FIN, frameType)

    // RSV bits must be 0 unless a negotiated extension defines them.
    // They are recorded as received; no validation happens here.
    frame.rsv1 = (head & RSV1) == RSV1
    frame.rsv2 = (head & RSV2) == RSV2
    frame.rsv3 = (head & RSV3) == RSV3

    // MASK bit: when set, a 32-bit masking key is part of the header
    frame.mask = (second & MASK) == MASK

    // Remaining 7 bits: the payload length itself, or the 126/127 markers
    // announcing a 16-bit or 64-bit extended length
    frame.payloadLen = second & PAYLOADLEN127

    frame
}

/**
 * Encodes the header of an outgoing frame, i.e. everything except the payload.
 *
 * The first byte carries FIN and the opcode. RSV1-3 stay 0 because
 * extensions are not supported yet. The payload length uses the form
 * from RFC 6455 5.2:
 *   - 0..125           -> stored directly in the 7-bit field
 *   - 126..MAXUINT16   -> marker 126 followed by a 16-bit length
 *   - larger           -> marker 127 followed by a 64-bit length
 * When `isClient` is true, the MASK bit is set and a fresh 4-byte masking
 * key is appended after the length fields.
 *
 * Returns a slice of 2, 4, 6, 8, 10 or 14 bytes.
 * NOTE(review): UnknownWebFrame is encoded with opcode 0 (continuation),
 * and a negative payloadLength is only rejected by the UInt8 conversion.
 * Confirm callers never pass either.
 */
func toWebSocketFrameBytesExceptPayload(fin: Bool, frameType: WebSocketFrameType, payloadLength: Int64, isClient: Bool): Array<UInt8> {
    // 2 fixed bytes + at most 8 extended-length bytes + 4 masking-key bytes
    let header = Array<UInt8>(14, repeat: 0)

    let opcode: UInt8 = match (frameType) {
        case ContinuationWebFrame => CONTINUATIONCODE
        case TextWebFrame => TEXTCODE
        case BinaryWebFrame => BINARYCODE
        case CloseWebFrame => CLOSECODE
        case PingWebFrame => PINGCODE
        case PongWebFrame => PONGCODE
        case _ => 0u8
    }
    header[0] = if (fin) { FIN | opcode } else { opcode }

    // `used` tracks how many header bytes have been written so far
    var used = 2
    if (payloadLength <= 125) {
        header[1] = UInt8(payloadLength)
    } else if (payloadLength <= MAXUINT16) {
        header[1] = PAYLOADLEN126
        fromInt64(payloadLength, 2).copyTo(header, 0, 2, 2)
        used = 4
    } else {
        header[1] = PAYLOADLEN127
        fromInt64(payloadLength, 8).copyTo(header, 0, 2, 8)
        used = 10
    }

    // Client-to-server frames are masked: set the MASK bit and append the key
    if (isClient) {
        header[1] = header[1] | MASK
        generateMaskingKey().copyTo(header, 0, used, 4)
        used += 4
    }
    return header[..used]
}

public enum WebSocketFrameType <: Equatable<WebSocketFrameType> & ToString {
    ContinuationWebFrame
    | TextWebFrame
    | BinaryWebFrame
    | CloseWebFrame
    | PingWebFrame
    | PongWebFrame
    | UnknownWebFrame

    /**
     * Returns the constructor name of this frame type, e.g. "TextWebFrame".
     */
    public override func toString(): String {
        match (this) {
            case ContinuationWebFrame => "ContinuationWebFrame"
            case TextWebFrame => "TextWebFrame"
            case BinaryWebFrame => "BinaryWebFrame"
            case CloseWebFrame => "CloseWebFrame"
            case PingWebFrame => "PingWebFrame"
            case PongWebFrame => "PongWebFrame"
            case UnknownWebFrame => "UnknownWebFrame"
        }
    }

    /**
     * Two frame types are equal exactly when they are the same constructor.
     */
    public override operator func ==(that: WebSocketFrameType): Bool {
        match ((this, that)) {
            case (ContinuationWebFrame, ContinuationWebFrame) => return true
            case (TextWebFrame, TextWebFrame) => return true
            case (BinaryWebFrame, BinaryWebFrame) => return true
            case (CloseWebFrame, CloseWebFrame) => return true
            case (PingWebFrame, PingWebFrame) => return true
            case (PongWebFrame, PongWebFrame) => return true
            case (UnknownWebFrame, UnknownWebFrame) => return true
            case _ => return false
        }
    }

    /**
     * Negation of `==`.
     */
    public override operator func !=(that: WebSocketFrameType): Bool {
        if (this == that) { false } else { true }
    }
}

/**
 * Produces a fresh 4-byte masking key. It is used by the header encoder when
 * building client-side (masked) frames. The key must come from a strong
 * source of entropy, so it is drawn from the global crypto kit's random
 * generator.
 */
func generateMaskingKey(): Array<UInt8> {
    let key = Array<UInt8>(4, repeat: 0)
    let rng = getGlobalCryptoKit().getRandomGen()
    rng.nextBytes(key)
    return key
}