/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

package stdx.crypto.keys

import stdx.crypto.digest.*
import stdx.crypto.common.*
import std.sync.AtomicReference

// Default SM2 user ID from GM/T 0009-2012: the ASCII bytes of "1234567812345678".
// Used as the signer ID for both signSM2 and verifySM2 calls in this file.
let defaultUid: Array<Byte> = [0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
    0x38]

/**
 * SM2 private key backed by a native key handle.
 *
 * The handle in `pkey` is released by the finalizer. Every method that passes `pkey`
 * to native code calls `keepAlive(this)` AFTER the native call has returned, so the
 * object (and therefore its finalizer) cannot be collected while the call still uses
 * the handle.
 */
public class SM2PrivateKey <: PrivateKey {
    // Native key handle owned by this object; freed with keyFree in ~init.
    var pkey = CPointer<UInt64>()

    // Loads a private key from the DER body of `blob`.
    // Throws CryptoException if the native loader rejects the encoding.
    init(blob: DerBlob) {
        pkey = loadPrivateKey(NID_sm2, blob.body)
        if (pkey.isNull()) {
            throw CryptoException("Load private key failed.")
        }
    }

    // Generates a fresh SM2 key pair. Throws CryptoException on native failure.
    public init() {
        pkey = unsafe { generateSM2Key() }
        if (pkey.isNull()) {
            throw CryptoException("Init private key failed.")
        }
    }

    // Returns the unencrypted DER encoding of this private key.
    public func encodeToDer(): DerBlob {
        let content = getSM2PrivateKeyDer(pkey)
        // Keep `this` reachable until the native encoder has finished with `pkey`.
        keepAlive(this)
        DerBlob(content)
    }

    // Returns the DER encoding, re-encoded with `password` through GeneralPrivateKey
    // when a password is given; otherwise identical to encodeToDer().
    public func encodeToDer(password!: ?String): DerBlob {
        let derBlob: DerBlob = encodeToDer()
        match (password) {
            case Some(password) => try {
                let key = GeneralPrivateKey.decodeDer(derBlob)
                key.encodeToDer(password: password)
            } catch (e: CryptoException) {
                throw CryptoException(e.message)
            }
            case None => derBlob
        }
    }

    // Returns this key as a PEM entry labelled LABEL_EC_PRIVATE_KEY.
    public func encodeToPem(): PemEntry {
        PemEntry(PemEntry.LABEL_EC_PRIVATE_KEY, encodeToDer())
    }

    // Returns this key as a PEM entry, password-protected when `password` is given.
    public func encodeToPem(password!: ?String): PemEntry {
        PemEntry(PemEntry.LABEL_EC_PRIVATE_KEY, encodeToDer(password: password))
    }

    // Decodes an unencrypted DER private key.
    public static func decodeDer(blob: DerBlob): SM2PrivateKey {
        SM2PrivateKey(blob)
    }

    // Decodes a DER private key, first decrypting it via GeneralPrivateKey when a password is given.
    public static func decodeDer(blob: DerBlob, password!: ?String): SM2PrivateKey {
        match (password) {
            case Some(password) => try {
                let priKey = GeneralPrivateKey.decodeDer(blob, password: password)
                SM2PrivateKey(priKey.encodeToDer())
            } catch (e: CryptoException) {
                throw CryptoException(e.message)
            }
            case None => SM2PrivateKey(blob)
        }
    }

    // Decodes an unencrypted PEM private key.
    public static func decodeFromPem(text: String): SM2PrivateKey {
        decodeFromPem(text, password: None)
    }

    // Decodes a PEM private key, optionally password-protected.
    public static func decodeFromPem(
        text: String,
        password!: ?String
    ): SM2PrivateKey {
        try {
            let priKey = GeneralPrivateKey.decodeFromPem(text, password: password)
            SM2PrivateKey(priKey.encodeToDer())
        } catch (e: CryptoException) {
            throw CryptoException(e.message)
        }
    }

    // Signs `data` using SM3 and the default GM/T 0009-2012 user ID.
    public func sign(data: Array<Byte>): Array<Byte> {
        let signature = signSM2(pkey, data, defaultUid)
        // Must come after the native call: before it, `this` could become unreachable
        // once `pkey` is read and the finalizer could free the handle mid-call.
        keepAlive(this)
        return signature
    }

    // Decrypts an SM2 ciphertext produced by the matching public key.
    public func decrypt(input: Array<Byte>): Array<Byte> {
        let plaintext = decryptSM2(pkey, input)
        // See sign(): keep `this` alive until native code is done with `pkey`.
        keepAlive(this)
        return plaintext
    }

    public override func toString(): String {
        "SM2 PRIVATE KEY"
    }

    // Releases the native key handle, if one was created.
    ~init() {
        if (!pkey.isNull()) {
            keyFree(pkey)
        }
    }
}

/**
 * SM2 public key backed by a native key handle.
 *
 * The handle in `pkey` is released by the finalizer. Native calls that use a key handle
 * are followed (not preceded) by `keepAlive` on the owning object, so that object cannot
 * be finalized while the native call is still using its handle.
 */
public class SM2PublicKey <: PublicKey {
    // Native key handle owned by this object; freed with keyFree in ~init.
    var pkey = CPointer<UInt64>()

    // Loads a public key from the DER body of `blob`.
    // Throws CryptoException if the native loader rejects the encoding.
    init(blob: DerBlob) {
        pkey = loadPublicKey(NID_sm2, blob.body)
        if (pkey.isNull()) {
            throw CryptoException("Load public key failed.")
        }
    }

    // Derives the public key from `pri` by round-tripping through its public DER encoding.
    public init(pri: SM2PrivateKey) {
        let content = getSM2PublicKeyDer(pri.pkey)
        // `pri` may otherwise become unreachable right after `pri.pkey` is read, letting its
        // finalizer free the handle while the native encoder is still using it.
        keepAlive(pri)
        pkey = loadPublicKey(NID_sm2, content)
        if (pkey.isNull()) {
            throw CryptoException("Init public key failed.")
        }
    }

    // Returns the DER (SubjectPublicKeyInfo-style, as produced by the native encoder) bytes.
    public func encodeToDer(): DerBlob {
        let content = getSM2PublicKeyDer(pkey)
        keepAlive(this)
        DerBlob(content)
    }

    // Returns this key as a PEM entry labelled LABEL_PUBLIC_KEY.
    public func encodeToPem(): PemEntry {
        PemEntry(PemEntry.LABEL_PUBLIC_KEY, encodeToDer())
    }

    // Decodes a DER public key.
    public static func decodeDer(blob: DerBlob): SM2PublicKey {
        SM2PublicKey(blob)
    }

    // Decodes a PEM public key via GeneralPublicKey.
    public static func decodeFromPem(text: String): SM2PublicKey {
        try {
            let pubKey = GeneralPublicKey.decodeFromPem(text)
            SM2PublicKey(pubKey.encodeToDer())
        } catch (e: CryptoException) {
            throw CryptoException(e.message)
        }
    }

    // Verifies `sig` over `data` using SM3 and the default GM/T 0009-2012 user ID.
    // Returns false for a non-matching signature; throws on other native errors.
    public func verify(data: Array<Byte>, sig: Array<Byte>): Bool {
        let valid = verifySM2(pkey, data, sig, defaultUid)
        // Must come after the native call so the handle stays valid for its whole duration.
        keepAlive(this)
        return valid
    }

    // Encrypts `input` for the holder of the matching private key.
    public func encrypt(input: Array<Byte>): Array<Byte> {
        let ciphertext = encryptSM2(pkey, input)
        keepAlive(this)
        return ciphertext
    }

    public override func toString(): String {
        "SM2 PUBLIC KEY"
    }

    // Releases the native key handle, if one was created.
    ~init() {
        if (!pkey.isNull()) {
            keyFree(pkey)
        }
    }
}

// Generates a new SM2 key pair through the native key-context API and returns the raw key handle.
// Throws CryptoException if context creation, allocation, or any init/keygen step fails.
// The caller owns the returned handle (SM2PrivateKey stores it and frees it with keyFree in ~init).
unsafe func generateSM2Key() {
    var pkey = CPointer<UInt64>()
    var ctx = keyCtxNewId(NID_sm2, CPointer<UInt64>())
    if (ctx.isNull()) {
        throw CryptoException("Generate keypair failed, create key ctx error.")
    }
    // ppkey is a malloc'd out-slot that keygen writes the generated key pointer into;
    // presumably asResource() frees it when the try-with-resources block exits.
    try (ppkey = LibC.malloc<CPointer<UInt64>>().asResource()) {
        if (ppkey.value.isNull()) {
            throw CryptoException("Generate key failed, paramgen init error, malloc failed.")
        }

        // NOTE(review): paramgenInit is called but no paramgen step follows before keygenInit
        // on the same ctx; confirm it is still required.
        if (paramgenInit(ctx) != 1) {
            throw CryptoException("Generate key failed, paramgen init error.")
        }
        if (keygenInit(ctx) != 1) {
            throw CryptoException("Generate key failed, keygen init error.")
        }
        if (setCurveNid(ctx, NID_sm2) <= 0) {
            throw CryptoException("Generate key failed, set curve nid error.")
        }
        // Seed the out-slot with the (still null) pkey before keygen fills it.
        ppkey.value.write(pkey)
        if (keygen(ctx, ppkey.value) != 1) {
            throw CryptoException("Generate keypair failed, keygen error.")
        }
        pkey = ppkey.value.read()
    } finally {
        // The key context is always released; the generated key handle is not.
        keyCtxFree(ctx)
    }
    pkey
}

// Serializes the native private key `pkey` to DER and returns the encoded bytes.
// Throws CryptoException if the pointer slot cannot be allocated or the encoder reports
// a negative or oversized length. Does not take ownership of `pkey`.
func getSM2PrivateKeyDer(pkey: CPointer<UInt64>) {
    unsafe {
        // Scratch buffer of getSize(pkey) * 8 bytes that the encoder writes into directly.
        // NOTE(review): the length check below runs only after privateKey2d has written, so it
        // cannot prevent an overrun if an encoding ever exceeds this buffer — confirm the sizing.
        var keySize = getSize(pkey)
        var readPrivateKeyBuf = Array<Byte>(Int64(keySize * 8), repeat: 0)
        let readPrivateKey: CPointerHandle<Byte> = acquireArrayRawData(readPrivateKeyBuf)
        var readPriLen: Int32
        // One-pointer slot passed to the encoder as a pointer-to-pointer.
        var readPriPtr = LibC.malloc<CPointer<Byte>>()
        if (readPriPtr.isNull()) {
            releaseArrayRawData(readPrivateKey)
            throw CryptoException("Memory allocation failed.")
        }
        // Point the slot at the start of the pinned buffer; privateKey2d writes through it.
        readPriPtr.write(readPrivateKey.pointer)
        try {
            readPriLen = privateKey2d(pkey, readPriPtr)
            if (readPriLen < 0 || Int64(readPriLen) > readPrivateKeyBuf.size) {
                throw CryptoException("Fail to load private key.")
            }
        } finally {
            LibC.free(readPriPtr)
            releaseArrayRawData(readPrivateKey)
        }
        readPrivateKeyBuf[0..Int64(readPriLen)]
    }
}

// Serializes the public part of the native key `pkey` to DER and returns the encoded bytes.
// Throws CryptoException if the pointer slot cannot be allocated or the encoder reports
// a negative or oversized length. Does not take ownership of `pkey`.
func getSM2PublicKeyDer(pkey: CPointer<UInt64>) {
    unsafe {
        // Scratch buffer of getSize(pkey) * 8 bytes that the encoder writes into directly.
        // NOTE(review): the length check below runs only after pubKey2d has written, so it
        // cannot prevent an overrun if an encoding ever exceeds this buffer — confirm the sizing.
        var keySize = getSize(pkey)
        var readBufPublicKey = Array<Byte>(Int64(keySize * 8), repeat: 0)
        let read: CPointerHandle<Byte> = acquireArrayRawData(readBufPublicKey)
        var readLen: Int32
        // One-pointer slot passed to the encoder as a pointer-to-pointer.
        var readPtr = LibC.malloc<CPointer<Byte>>()
        if (readPtr.isNull()) {
            releaseArrayRawData(read)
            throw CryptoException("Memory allocation failed.")
        }
        // Point the slot at the start of the pinned buffer; pubKey2d writes through it.
        readPtr.write(read.pointer)
        try {
            readLen = pubKey2d(pkey, readPtr)
            if (readLen < 0 || Int64(readLen) > readBufPublicKey.size) {
                throw CryptoException("Fail to load public key.")
            }
        } finally {
            LibC.free(readPtr)
            releaseArrayRawData(read)
        }
        readBufPublicKey[0..Int64(readLen)]
    }
}

// Encrypts `input` with the SM2 public key `pkey` and returns the ciphertext bytes.
// Empty input yields an empty array without touching native code.
// Throws CryptoException if any native step fails or if the ciphertext size reported by the
// size query would not fit into the output buffer.
func encryptSM2(pkey: CPointer<UInt64>, input: Array<Byte>) {
    if (input.size == 0) {
        return Array<Byte>()
    }
    unsafe {
        let ctx: CPointer<UInt64> = keyCtxNew(pkey, CPointer<UInt64>())
        if (ctx.isNull()) {
            throw CryptoException("Encrypt failed, create key ctx error.")
        }
        // Headroom of 256 bytes above the plaintext length for the SM2 ciphertext overhead.
        let output = Array<Byte>(input.size + 256, repeat: 0)
        let inputBlock: CPointerHandle<Byte> = acquireArrayRawData(input)
        let outputBlock: CPointerHandle<Byte> = acquireArrayRawData(output)
        var outlenSize = 0
        try (outlenPtr = LibC.malloc<UIntNative>().asResource()) {
            if (outlenPtr.value.isNull()) {
                throw CryptoException("Encrypt failed, encrypt init error, malloc failed.")
            }

            if (encryInit(ctx) != 1) {
                throw CryptoException("Encrypt failed, encrypt init error.")
            }
            // First call with a null output pointer: queries the ciphertext size into outlenPtr.
            if (encrypt(ctx, CPointer<Byte>(), outlenPtr.value, inputBlock.pointer, UIntNative(input.size)) != 1) {
                throw CryptoException("Encrypt prepare failed.")
            }
            // Native code writes straight into `output`'s storage, so refuse to continue if the
            // reported size exceeds the buffer rather than risk overrunning it.
            if (outlenPtr.value.read() > UIntNative(output.size)) {
                throw CryptoException("Encrypt failed, output buffer too small.")
            }
            if (encrypt(ctx, outputBlock.pointer, outlenPtr.value, inputBlock.pointer, UIntNative(input.size)) != 1) {
                throw CryptoException("Encrypt failed.")
            }
            outlenSize = Int64(outlenPtr.value.read())
        } finally {
            releaseArrayRawData(inputBlock)
            releaseArrayRawData(outputBlock)
            keyCtxFree(ctx)
        }
        output[..outlenSize]
    }
}

// Decrypts the SM2 ciphertext `input` with the private key `pkey` and returns the plaintext.
// Empty input yields an empty array without touching native code.
// Throws CryptoException if any native step fails or if the plaintext size reported by the
// size query would not fit into the output buffer.
func decryptSM2(pkey: CPointer<UInt64>, input: Array<Byte>) {
    if (input.size == 0) {
        return Array<Byte>()
    }
    unsafe {
        let ctx: CPointer<UInt64> = keyCtxNew(pkey, CPointer<UInt64>())
        if (ctx.isNull()) {
            throw CryptoException("Decrypt failed, create key ctx error.")
        }
        // Plaintext is never longer than the ciphertext, so the input length is used as capacity.
        let output = Array<Byte>(input.size, repeat: 0)
        let inputBlock: CPointerHandle<Byte> = acquireArrayRawData(input)
        let outputBlock: CPointerHandle<Byte> = acquireArrayRawData(output)
        var reSize = 0
        try (outlenPtr = LibC.malloc<UIntNative>().asResource()) {
            if (outlenPtr.value.isNull()) {
                throw CryptoException("Decrypt failed, decrypt init error, malloc failed.")
            }

            if (decryptInit(ctx) != 1) {
                throw CryptoException("Decrypt failed, decrypt init error.")
            }
            // First call with a null output pointer: queries the plaintext size into outlenPtr.
            if (decrypt(ctx, CPointer<Byte>(), outlenPtr.value, inputBlock.pointer, UIntNative(input.size)) != 1) {
                throw CryptoException("Decrypt prepare failed.")
            }
            // Guard the native write into `output`'s storage against an oversized report.
            if (outlenPtr.value.read() > UIntNative(output.size)) {
                throw CryptoException("Decrypt failed, output buffer too small.")
            }
            if (decrypt(ctx, outputBlock.pointer, outlenPtr.value, inputBlock.pointer, UIntNative(input.size)) != 1) {
                throw CryptoException("Decrypt failed.")
            }
            reSize = Int64(outlenPtr.value.read())
        } finally {
            releaseArrayRawData(inputBlock)
            releaseArrayRawData(outputBlock)
            keyCtxFree(ctx)
        }
        output[..reSize]
    }
}

// Signs `data` with the SM2 private key `pkey`, using SM3 as the digest and `uid` as the signer ID.
// Returns the signature bytes. Throws CryptoException if any native step fails or if the
// signature size reported by the size query would not fit into the 128-byte buffer.
func signSM2(pkey: CPointer<UInt64>, data: Array<Byte>, uid: Array<Byte>) {
    unsafe {
        let pData: CPointerHandle<Byte> = acquireArrayRawData(data)
        var buffSize = 128
        let sig: Array<UInt8> = Array<UInt8>(buffSize, repeat: 0)
        let sigHandle: CPointerHandle<Byte> = acquireArrayRawData(sig)
        // Null until created; freeData is invoked in `finally` regardless of how far setup got.
        var ctx = CPointer<UInt64>()
        var mdctx = CPointer<UInt64>()
        try (sizePtr = LibC.malloc<Int64>().asResource()) {
            if (sizePtr.value.isNull()) {
                throw CryptoException("Sign failed, digest sign init error, malloc failed.")
            }

            ctx = generateCtx(pkey)
            mdctx = createMdCtx()
            setPkeyCtx(ctx, mdctx, uid)
            if (digestSignInit(mdctx, CPointer<UInt64>(), sm3(), CPointer<UInt64>(), pkey) != 1) {
                throw CryptoException("Sign failed, digest sign init error.")
            }
            if (digestSignUpdate(mdctx, pData.pointer, UIntNative(data.size)) != 1) {
                throw CryptoException("Sign failed, digest sign update error.")
            }
            // First final call with a null buffer: queries the signature size into sizePtr.
            if (digestSignFinal(mdctx, CPointer<Byte>(), sizePtr.value) != 1) {
                throw CryptoException("Sign prepare failed.")
            }
            // Native code writes straight into `sig`'s storage; refuse to continue if the
            // reported size does not fit instead of risking an overrun.
            if (sizePtr.value.read() > sig.size) {
                throw CryptoException("Sign failed, signature buffer too small.")
            }
            if (digestSignFinal(mdctx, sigHandle.pointer, sizePtr.value) != 1) {
                throw CryptoException("Sign failed.")
            }
            buffSize = sizePtr.value.read()
        } finally {
            releaseArrayRawData(pData)
            releaseArrayRawData(sigHandle)
            freeData(ctx, mdctx)
        }
        sig[0..buffSize]
    }
}

// Verifies signature `sig` over `data` with the SM2 key `pkey`, using SM3 and signer ID `uid`.
// Returns true when digestVerifyFinal reports 1 and false when it reports 0; any other
// result, or a failure in the setup steps, raises CryptoException.
func verifySM2(pkey: CPointer<UInt64>, data: Array<Byte>, sig: Array<Byte>, uid: Array<Byte>) {
    unsafe {
        let pData: CPointerHandle<Byte> = acquireArrayRawData(data)
        let sigData: CPointerHandle<Byte> = acquireArrayRawData(sig)
        var resFinal: Int32 = 0
        // Null until created; freeData is invoked in `finally` regardless of how far setup got.
        var ctx = CPointer<UInt64>()
        var mdctx = CPointer<UInt64>()
        try {
            ctx = generateCtx(pkey)
            mdctx = createMdCtx()
            setPkeyCtx(ctx, mdctx, uid)
            if (digestVerifyInit(mdctx, CPointer<UInt64>(), sm3(), CPointer<UInt64>(), pkey) != 1) {
                throw CryptoException("Verify failed, digest verify init error.")
            }
            if (digestVerifyUpdate(mdctx, pData.pointer, UIntNative(data.size)) != 1) {
                throw CryptoException("Verify failed, digest verify update error.")
            }
            resFinal = digestVerifyFinal(mdctx, sigData.pointer, UIntNative(sig.size))
            // 0 is treated as "signature does not match", not as an error.
            if (resFinal != 1 && resFinal != 0) {
                throw CryptoException("Verify failed.")
            }
        } finally {
            releaseArrayRawData(pData)
            releaseArrayRawData(sigData)
            freeData(ctx, mdctx)
        }
        return resFinal == 1
    }
}

// Creates a new native key context bound to `pkey`.
// Throws CryptoException when the native allocation returns null; the caller must free the result.
func generateCtx(pkey: CPointer<UInt64>): CPointer<UInt64> {
    let keyContext = unsafe { keyCtxNew(pkey, CPointer<UInt64>()) }
    if (keyContext.isNull()) {
        throw CryptoException("Create ctx failed.")
    }
    return keyContext
}

// Allocates a fresh native message-digest context.
// Throws CryptoException when the native allocation returns null; the caller must free the result.
func createMdCtx(): CPointer<UInt64> {
    let digestContext = unsafe { mdCtxNew() }
    if (digestContext.isNull()) {
        throw CryptoException("Create md ctx failed.")
    }
    return digestContext
}

// Sets the signer ID `uid` on key context `ctx`, then attaches `ctx` to digest context `mdctx`.
// The `uid` bytes are pinned only for the duration of the ID call.
// Throws CryptoException if the native layer rejects the ID.
func setPkeyCtx(ctx: CPointer<UInt64>, mdctx: CPointer<UInt64>, uid: Array<Byte>): Unit {
    unsafe {
        let uidHandle: CPointerHandle<Byte> = acquireArrayRawData(uid)
        var idAccepted = false
        try {
            idAccepted = keyCtxSetId(ctx, uidHandle.pointer, uid.size) > 0
        } finally {
            releaseArrayRawData(uidHandle)
        }
        if (!idAccepted) {
            throw CryptoException("set md ctx failed.")
        }
        setKeyCtx(mdctx, ctx)
    }
}

// Releases a key context and a digest context, in that order.
func freeData(ctx: CPointer<UInt64>, mdctx: CPointer<UInt64>): Unit {
    unsafe { keyCtxFree(ctx) }
    unsafe { mdCtxFree(mdctx) }
}

// Placeholder object held by _sm2PrivateKey whenever no key is being pinned.
let _privateKeySM2Phantom = Object()
let _sm2PrivateKey = AtomicReference<Object>(_privateKeySM2Phantom)

// Stores `o` into a global atomic reference and immediately swaps the placeholder back.
// Callers invoke this after the last native use of `o.pkey` (see encodeToDer), presumably so
// the atomic store acts as a reachability fence that keeps `o` — and thus its finalizer —
// from being collected earlier. NOTE(review): confirm against the runtime's GC guarantees.
// The two stores must stay in this order; the second drops the global reference again.
func keepAlive(o: SM2PrivateKey) {
    _sm2PrivateKey.store(o)
    _sm2PrivateKey.store(_privateKeySM2Phantom)
}

// Placeholder object held by _sm2PublicKey whenever no key is being pinned.
let _publicKeySM2Phantom = Object()
let _sm2PublicKey = AtomicReference<Object>(_publicKeySM2Phantom)

// Public-key counterpart of keepAlive(SM2PrivateKey): stores `o`, then restores the placeholder.
// Presumably used as a reachability fence after the last native use of `o.pkey`, keeping the
// finalizer from freeing the handle early. NOTE(review): confirm against the runtime's GC guarantees.
// The two stores must stay in this order; the second drops the global reference again.
func keepAlive(o: SM2PublicKey) {
    _sm2PublicKey.store(o)
    _sm2PublicKey.store(_publicKeySM2Phantom)
}