/*
 * This file is part of the openHiTLS project.
 *
 * openHiTLS is licensed under the Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *     http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#include "hitls_build.h"
#ifdef HITLS_CRYPTO_SM9

#include "sm9.h"
#include "sm9_curve.h"
#include "sm9_pairing.h"
#include "sm9_fp.h"
#include "crypt_util_rand.h"
#include "crypt_errno.h"
#include "bsl_sal.h"
#include "bsl_bytes.h"

#include <stdint.h>
#include <string.h>

/***************************Compiler-Switches**********************************/

/* When defined, SM9_Sign and SM9_Alg_Enc use the caller-supplied random value r.
 * When undefined, they overwrite r with the deterministic, NON-random
 * _ibc_alg_rand() pattern; it must stay defined in production builds. */
#define SM9_ALG_RND_ENABLE

/* Enables SM9_Alg_Pair (bilinear pairing on serialised points). */
#define SM9_PAIR_ENABLE

/* Signature master/user key generation (SM9_SIG_SYS_ENABLE) and the
 * sign/verify helpers (SM9_SIG_USR_ENABLE). */
#define SM9_SIG_SYS_ENABLE
#define SM9_SIG_USR_ENABLE

/* Encryption master/user key generation (SM9_ENC_SYS_ENABLE) and
 * encrypt/decrypt (SM9_ENC_USR_ENABLE). */
#define SM9_ENC_SYS_ENABLE
#define SM9_ENC_USR_ENABLE

/* Number of leading bytes of the Ha buffer that H1/H2 turn into an integer
 * before reducing it mod (N - 1). */
#define SM9_HByteLen        40
/* The same length in 32-bit words. */
#define SM9_HWordLen        (SM9_HByteLen/4)

/* Byte length of ciphertext part C1 (a G1 point, two field elements). */
#define SM9_C1_ByteLen        (2*BNByteLen)

/* Byte length of ciphertext part C3 (the MAC). */
#define SM9_C3_ByteLen        (BNByteLen)

/*******************************Static Global *********************************/
/* hid byte appended to the user identity before H1 when deriving signing keys. */
static const uint8_t SM9_HID_Sign[1]  = {0x01};
/* hid byte appended to the user identity before H1 when deriving encryption keys. */
static const uint8_t SM9_HID_Enc[1]   = {0x03};

/******************************************************************************/
/*                    SM3 Adaptation Layer for SM9                           */
/******************************************************************************/

// Forward declarations from the SM3 library. The ctx argument passed by this
// file is &SM9_Hash_Ctx::sm3State, which is assumed to be layout-compatible
// with the SM3 library's own context -- NOTE(review): confirm against the SM3
// implementation whenever either structure changes.
extern int32_t CRYPT_SM3_Update(void *ctx, const uint8_t *in, uint32_t len);
extern int32_t CRYPT_SM3_Final(void *ctx, uint8_t *out, uint32_t *outLen);

/*
 * Compute the R-ate pairing of a G1 point and a G2 point into pFp12_R.
 * Callers pass affine points read with the *_ReadBytes helpers, which the
 * call-site comments describe as Montgomery form.
 *
 * This function returns void, so the pairing is invoked as a plain statement:
 * "return expr;" is not permitted in a void function in C (C11 6.8.6.4p1).
 * NOTE(review): if SM9_Pairing_R_Ate reports a status, it is not propagated.
 */
static void SM9_Alg_Pair_Mont(SM9_Fp12 *pFp12_R, SM9_ECP_A *pEcp_P1, SM9_ECP2_A *pEcp2_P2)
{
    SM9_Pairing_R_Ate(pFp12_R, pEcp_P1, pEcp2_P2);
}

/*
 * Reset an SM9 hash context to the SM3 initial state, mirroring
 * CRYPT_SM3_Init: the whole SM3 state is wiped, then the eight chaining
 * words are loaded with the SM3 IV (GM/T 0004-2012, section 4.1).
 */
void SM9_Hash_Init(SM9_Hash_Ctx *ctx)
{
    static const uint32_t sm3Iv[] = {
        0x7380166F, 0x4914B2B9, 0x172442D7, 0xDA8A0600,
        0xA96F30BC, 0x163138AA, 0xE38DEE4D, 0xB0FB0E4E
    };

    BSL_SAL_CleanseData(&ctx->sm3State, sizeof(SM9_CRYPT_SM3_Ctx));
    for (size_t idx = 0; idx < sizeof(sm3Iv) / sizeof(sm3Iv[0]); idx++) {
        ctx->sm3State.h[idx] = sm3Iv[idx];
    }
}

/* Absorb len bytes of data into the running SM3 computation.
 * This API returns void, so the SM3 library's status is deliberately dropped. */
void SM9_Hash_Update(SM9_Hash_Ctx *ctx, const uint8_t *data, uint32_t len)
{
    (void)CRYPT_SM3_Update(&ctx->sm3State, data, len);
}

/* Finish the SM3 computation and write the 32-byte digest to `digest`.
 * This API returns void, so the SM3 library's status is deliberately dropped. */
void SM9_Hash_Final(SM9_Hash_Ctx *ctx, uint8_t *digest)
{
    uint32_t digestCapacity = 32; /* bytes available at `digest` */

    (void)CRYPT_SM3_Final(&ctx->sm3State, digest, &digestCapacity);
}

/*
 * Byte-wise XOR of two buffers: pbXor[i] = digest[i] ^ pbMsg[i] for
 * i in [0, len). pbXor may be the very same buffer as either input, because
 * each output byte is written only after both input bytes at that index are read.
 */
void SM9_Xor(uint8_t *pbXor, uint32_t len, const uint8_t *digest, uint8_t *pbMsg)
{
    uint32_t pos = 0;

    while (pos < len) {
        pbXor[pos] = (uint8_t)(digest[pos] ^ pbMsg[pos]);
        ++pos;
    }
}

/* One-shot SM3 over data[0..len); the 32-byte digest is written to `digest`. */
void SM9_Hash_Data(const uint8_t *data, uint32_t len, uint8_t *digest)
{
    SM9_Hash_Ctx oneShot;

    SM9_Hash_Init(&oneShot);
    SM9_Hash_Update(&oneShot, data, len);
    SM9_Hash_Final(&oneShot, digest);
}

/******************************************************************************/
/*                    Random Number Generation for SM9                       */
/******************************************************************************/

/*
 * Fill p[0..len) from the library's cryptographic random number generator.
 * A NULL library context is passed to CRYPT_RandEx; its status code is
 * returned unchanged.
 */
int32_t sm9_rand(uint8_t *p, uint32_t len)
{
    int32_t status = CRYPT_RandEx(NULL, p, len);

    return status;
}

/*
 * Normalise a key scalar in place: read `key` as a big number, reduce it with
 * SM9_Fn_LastRes, add 1, reduce again, and write the result back to `key`.
 *
 * NOTE(review): assuming SM9_Fn_LastRes reduces into [0, N-1], the second
 * reduction maps N back to 0, so the result lies in [0, N-1] rather than
 * [1, N-1]; 0 is produced when the input is congruent to N-1 modulo N.
 * Callers that need a non-zero key should check for it.
 * NOTE(review): the result is written with SM9_Fp_WriteBytes although it is an
 * Fn scalar -- confirm that helper performs no Montgomery conversion.
 */
void SM9_ModifyKeyRange(uint8_t *key)
{
    uint32_t bn[BNWordLen];
    SM9_Bn_ReadBytes(bn, key);
    SM9_Fn_LastRes(bn);
    bn_add_int(bn, bn, 1, BNWordLen);
    SM9_Fn_LastRes(bn);
    SM9_Fp_WriteBytes(key, bn);
    // Wipe the stack copy of the key material.
    BSL_SAL_CleanseData(bn, sizeof(bn));
}
/******************************************************************************/

/*
 * Deterministic filler used in place of the caller's random value when
 * SM9_ALG_RND_ENABLE is not defined: r[i] = (uint8_t)i, i.e. 0, 1, 2, ...
 * wrapping after 255.
 * WARNING: this is NOT random. It must never supply real keys or nonces.
 * (External linkage is kept for existing callers, although the author intended
 * this helper to be file-local.)
 */
void _ibc_alg_rand(uint8_t *r, uint32_t len)
{
    for (uint32_t pos = 0; pos < len; pos++) {
        r[pos] = (uint8_t)pos;
    }
}

/* Serialise the BNWordLen-word big number `src` into `dst` via BNToByte.
 * The byte count reported by BNToByte is not needed by callers. */
static void _ibc_write_fpbytes(uint8_t *dst, uint32_t *src)
{
    int32_t written;

    BNToByte(src, BNWordLen, dst, &written);
    (void)written;
}

/*
 * Shared core of the SM9 hash functions H1 and H2:
 *   Ha = SM3(tag || msg || add || 00000001) || SM3(tag || msg || add || 00000002)
 *   h  = (leading SM9_HByteLen bytes of Ha, as an integer, mod (N - 1)) + 1
 * `add` is optional: when NULL it is skipped. The result (BNWordLen words)
 * is written to pwH.
 */
static void _sm9_hash_h(uint32_t *pwH, uint8_t tag, const uint8_t * msg, uint32_t mlen, const uint8_t *add, uint32_t alen)
{
    SM9_Hash_Ctx hashCtx;
    uint8_t counter[4] = {0x00, 0x00, 0x00, 0x01};
    uint8_t haBytes[2 * BNByteLen];
    uint32_t haWords[SM9_HWordLen + 1];
    uint32_t nMinusOne[BNWordLen];
    uint8_t *block = haBytes;

    for (int32_t round = 0; round < 2; round++) {
        SM9_Hash_Init(&hashCtx);
        SM9_Hash_Update(&hashCtx, &tag, 1);
        SM9_Hash_Update(&hashCtx, msg, mlen);
        if (add != NULL) {
            SM9_Hash_Update(&hashCtx, add, alen);
        }
        SM9_Hash_Update(&hashCtx, counter, 4);
        SM9_Hash_Final(&hashCtx, block);
        block += BNByteLen;
        counter[3]++;
    }

    /* Only the leading SM9_HByteLen bytes of Ha are used; the extra top word is zero. */
    ByteToBN(haBytes, SM9_HByteLen, haWords, SM9_HWordLen);
    haWords[SM9_HWordLen] = 0x00000000;

    /* h = (Ha mod (N - 1)) + 1 */
    bn_sub_int(nMinusOne, sm9_sys_para.EC_N, 1, BNWordLen);
    BN_Mod_Basic(pwH, BNWordLen, haWords, SM9_HWordLen, nMinusOne, BNWordLen);
    bn_add_int(pwH, pwH, 1, BNWordLen);
}

/* SM9 hash H1: domain tag 0x01, then msg, then the optional `add` bytes. */
static void SM9_Hash_H1(uint32_t *pwH, const uint8_t *msg, uint32_t mlen, const uint8_t *add, uint32_t alen)
{
    const uint8_t h1Tag = 0x01;

    _sm9_hash_h(pwH, h1Tag, msg, mlen, add, alen);
}

/* SM9 hash H2: domain tag 0x02, then msg, then the optional `add` bytes.
 * `add` is only read, so it is const-qualified exactly like SM9_Hash_H1's. */
static void SM9_Hash_H2(uint32_t *pwH, const uint8_t *msg, uint32_t mlen, const uint8_t *add, uint32_t alen)
{
    _sm9_hash_h(pwH, 0x02, msg, mlen, add, alen);
}

/******************************************************************************/

/*
 * Report the implementation version and the compiled-in feature set.
 *
 * Layout of the returned value:
 *   bits 31..24  major version (1)
 *   bits 23..16  minor version (1)
 *   bits 15..8   patch version (0)
 *   bits  7..0   feature flags:
 *                0x01 SM9_ALG_RND_ENABLE   0x02 SM9_PAIR_ENABLE
 *                0x04 SM9_SIG_USR_ENABLE   0x08 SM9_SIG_SYS_ENABLE
 *                0x40 SM9_ENC_USR_ENABLE   0x80 SM9_ENC_SYS_ENABLE
 *
 * Declared with (void): an empty parameter list is an obsolescent,
 * unprototyped declarator in C (C11 6.11.6).
 */
int32_t SM9_Alg_GetVersion(void)
{
    const uint32_t major = 1;
    const uint32_t minor = 1;
    const uint32_t patch = 0;
    uint32_t features = 0;

#ifdef SM9_ALG_RND_ENABLE
    features |= 0x01u;
#endif // SM9_ALG_RND_ENABLE
#ifdef SM9_PAIR_ENABLE
    features |= 0x02u;
#endif // SM9_PAIR_ENABLE
#ifdef SM9_SIG_SYS_ENABLE
    features |= 0x08u;
#endif // SM9_SIG_SYS_ENABLE
#ifdef SM9_SIG_USR_ENABLE
    features |= 0x04u;
#endif // SM9_SIG_USR_ENABLE
#ifdef SM9_ENC_SYS_ENABLE
    features |= 0x80u;
#endif // SM9_ENC_SYS_ENABLE
#ifdef SM9_ENC_USR_ENABLE
    features |= 0x40u;
#endif // SM9_ENC_USR_ENABLE

    /* Fields do not overlap, so OR-ing is identical to adding them. */
    return (int32_t)((major << 24) | (minor << 16) | (patch << 8) | features);
}

/*
 * Byte-level pairing g = e(P1, P2).
 * p1 is a serialised G1 point and p2 a serialised G2 point; g receives the
 * serialised Fp12 result. The input points are not validated here.
 * Returns CRYPT_SUCCESS, or CRYPT_SM9_ERR_NOT_SUPPORT when SM9_PAIR_ENABLE
 * is not defined.
 */
int32_t SM9_Alg_Pair(uint8_t *g, uint8_t *p1, uint8_t *p2)
{
#ifndef SM9_PAIR_ENABLE
    (void)g;
    (void)p1;
    (void)p2;
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#else
    SM9_ECP_A g1Point;
    SM9_ECP2_A g2Point;
    SM9_Fp12 pairing;

    /* Deserialise both points (the helpers convert to Montgomery form). */
    SM9_Ecp_A_ReadBytes(&g1Point, p1);
    SM9_Ecp2_A_ReadBytes(&g2Point, p2);

    SM9_Alg_Pair_Mont(&pairing, &g1Point, &g2Point);
    SM9_Fp12_WriteBytes(g, &pairing);
    return CRYPT_SUCCESS;
#endif /* SM9_PAIR_ENABLE */
}

/* Signature system element g = e(P1, Ppub-s): pairs g_SM9_G1 (presumably the
 * serialised G1 generator P1) with the signing master public key mpk. */
int32_t SM9_Get_Sig_G(uint8_t *g, uint8_t *mpk)
{
    int32_t ret = SM9_Alg_Pair(g, g_SM9_G1, mpk);

    return ret;
}

/* Encryption system element g = e(Ppub-e, P2): pairs the encryption master
 * public key mpk with g_SM9_G2 (presumably the serialised G2 generator P2). */
int32_t SM9_Get_Enc_G(uint8_t *g, uint8_t *mpk)
{
    int32_t ret = SM9_Alg_Pair(g, mpk, g_SM9_G2);

    return ret;
}

/*
 * Derive the signing master public key Ppub-s = [ks]P2 from the signing
 * master private key ks and write it to mpk as a serialised G2 point.
 * ks is used as given; its range is not checked here.
 * The stack copy of ks is wiped before returning.
 * Returns CRYPT_SUCCESS, or CRYPT_SM9_ERR_NOT_SUPPORT when SM9_SIG_SYS_ENABLE
 * is not defined.
 */
int32_t SM9_Alg_MSKG(uint8_t *ks, uint8_t *mpk)
{
#ifdef SM9_SIG_SYS_ENABLE
    uint32_t BN_SK[BNWordLen];
    SM9_ECP2_A Point_PK;

    // Read System PriKey to BN
    SM9_Bn_ReadBytes(BN_SK, ks);
    // Generate System PubKey (affine point, Montgomery form)
    SM9_Ecp2_KP(&Point_PK, &sm9_sys_para.EC_Fp2_G_Mont, BN_SK);
    // Convert System PubKey to bytes (normal form)
    SM9_Ecp2_A_WriteBytes(mpk, &Point_PK);
    // The master private key must not linger on the stack.
    BSL_SAL_CleanseData(BN_SK, sizeof(BN_SK));

    return CRYPT_SUCCESS;
#else
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#endif /* SM9_SIG_SYS_ENABLE */
}

/*
 * Derive a user's signing private key from the signing master private key:
 *   t1 = H1(ID || hid, N) + ks   (mod N), rejected if zero
 *   t2 = ks * t1^-1              (mod N)
 *   ds = [t2]P1
 * id/ilen is the user identity, ks the master private key; ds receives the
 * serialised G1 point.
 * Returns CRYPT_SUCCESS, CRYPT_SM9_ERR_KEY_ERR when t1 == 0, or
 * CRYPT_SM9_ERR_NOT_SUPPORT when SM9_SIG_SYS_ENABLE is not defined.
 * All secret intermediates (ks, t1, t2, ds) are wiped from the stack on every
 * path.
 */
int32_t SM9_Alg_USKG(const uint8_t *id, uint32_t ilen, uint8_t *ks, uint8_t *ds)
{
#ifdef SM9_SIG_SYS_ENABLE
    int32_t ret = CRYPT_SUCCESS;
    uint32_t BN_t1[BNWordLen];
    uint32_t BN_t2[BNWordLen];
    SM9_ECP_A Ecp_ds;

    // Read System PriKey to BN
    SM9_Bn_ReadBytes(BN_t2, ks);

    // Compute t1 = H1(IDA||hid, N) + ks
    SM9_Hash_H1(BN_t1, id, ilen, SM9_HID_Sign, 1);
    SM9_Fn_Add(BN_t1, BN_t1, BN_t2);
    // t1 == 0 has no inverse: this identity cannot be issued a key under ks.
    if (SM9_Bn_IsZero(BN_t1)) {
        ret = CRYPT_SM9_ERR_KEY_ERR;
        goto EXIT;
    }
    // Compute t2 = ks * (t1^-1)
    BN_GetInv_Mont(BN_t1, BN_t1, sm9_sys_para.EC_N, sm9_sys_para.N_Mc, sm9_sys_para.N_R2, sm9_sys_para.wsize);
    bn_mont_mul(BN_t2, BN_t1, BN_t2, sm9_sys_para.EC_N, sm9_sys_para.N_Mc, sm9_sys_para.wsize);
    // Compute ds = [t2]P1
    SM9_Ecp_KP(&Ecp_ds, &sm9_sys_para.EC_Fp_G_Mont, BN_t2);

    // Convert User PriKey to bytes (normal form)
    SM9_Ecp_A_WriteBytes(ds, &Ecp_ds);

EXIT:
    BSL_SAL_CleanseData(BN_t1, sizeof(BN_t1));
    BSL_SAL_CleanseData(BN_t2, sizeof(BN_t2));
    BSL_SAL_CleanseData(&Ecp_ds, sizeof(Ecp_ds));
    return ret;
#else
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#endif /* SM9_SIG_SYS_ENABLE */
}

#ifdef SM9_SIG_USR_ENABLE
/*
 * Core of SM9 signing:
 *   r = caller's random value mapped by reduce / +1 / reduce, rejected if 0
 *   w = g^r, or e([r]P1, Ppub-s) when g is NULL
 *   h = H2(M || w, N)                      -> written to BN_h
 *   l = (r - h) mod N, rejected if zero
 *   S = [l]ds                              -> written to *Ecp_s
 * g (serialised Fp12) is optional; when it is NULL, mpk (serialised G2 point)
 * is used to compute w and *Ecp_s serves as scratch space.
 * Returns CRYPT_SUCCESS, or CRYPT_SM9_ERR_SIGN_FAILED when r or l is zero
 * (the caller should retry with a fresh r).
 * r, w and its encoding are wiped from the stack on every path: with r and a
 * published (h, S), ds = [(r - h)^-1]S would be recoverable.
 */
static int32_t _sm9_alg_sign(
    const uint8_t *msg, uint32_t mlen,
    const uint8_t *ds, uint8_t *r,
    const uint8_t *g, const uint8_t *mpk,
    uint32_t *BN_h, SM9_ECP_A *Ecp_s)
{
    int32_t ret = CRYPT_SUCCESS;
    uint32_t BN_r[BNWordLen];
    uint8_t pbW[12 * BNByteLen];
    SM9_Fp12 Fp12_g;
    SM9_ECP2_A Ecp2_P;

    // Read random number r (normal form) and fold it into range.
    SM9_Bn_ReadBytes(BN_r, r);
    SM9_Fn_LastRes(BN_r);
    bn_add_int(BN_r, BN_r, 1, BNWordLen);
    SM9_Fn_LastRes(BN_r);
    // r must lie in [1, N-1]; reject a zero result of the mapping above.
    if (SM9_Bn_IsZero(BN_r)) {
        ret = CRYPT_SM9_ERR_SIGN_FAILED;
        goto EXIT;
    }

    if (g != NULL) {
        // Read g from input (converted to Montgomery form) and compute w = g^r.
        SM9_Fp12_ReadBytes(&Fp12_g, g);
        SM9_Fp12_Exp(&Fp12_g, &Fp12_g, BN_r);
    } else {
        // Compute w = e([r]P1, Ppub-s) from the master public key.
        SM9_Ecp2_A_ReadBytes(&Ecp2_P, mpk);
        SM9_Ecp_KP(Ecp_s, &sm9_sys_para.EC_Fp_G_Mont, BN_r);
        SM9_Alg_Pair_Mont(&Fp12_g, Ecp_s, &Ecp2_P);
    }
    SM9_Fp12_WriteBytes(pbW, &Fp12_g);

    // h = H2(M||w, N)
    SM9_Hash_H2(BN_h, msg, mlen, pbW, 12 * BNByteLen);

    // l = (r - h) mod N; l == 0 would make S the point at infinity.
    SM9_Fn_Sub(BN_r, BN_r, BN_h);
    if (SM9_Bn_IsZero(BN_r)) {
        ret = CRYPT_SM9_ERR_SIGN_FAILED;
        goto EXIT;
    }
    // Read the user private key (Montgomery form) and compute S = [l]ds.
    SM9_Ecp_A_ReadBytes(Ecp_s, ds);
    SM9_Ecp_KP(Ecp_s, Ecp_s, BN_r);

EXIT:
    BSL_SAL_CleanseData(BN_r, sizeof(BN_r));
    BSL_SAL_CleanseData(pbW, sizeof(pbW));
    BSL_SAL_CleanseData(&Fp12_g, sizeof(Fp12_g));
    return ret;
}
#endif /* SM9_SIG_USR_ENABLE */

/*
 * Produce an SM9 signature (h, S) over msg.
 * opt selects how S is encoded by SM9_Fp_ECP_A_WriteBytesWithPC
 * (e.g. SM9_OPT_DM_MODE0). r is the caller's random value (replaced by a
 * deterministic pattern when SM9_ALG_RND_ENABLE is not defined), g the
 * optional precomputed e(P1, Ppub-s), mpk the signing master public key and
 * ds the signer's private key.
 * sign receives h (wsize * WordByteLen bytes) followed by the encoded S.
 * slen is optional; when non-NULL it receives the total signature length.
 * Returns CRYPT_SUCCESS, the signing core's error, CRYPT_SM9_ERR_BAD_INPUT
 * when S cannot be encoded, or CRYPT_SM9_ERR_NOT_SUPPORT when
 * SM9_SIG_USR_ENABLE is not defined.
 */
int32_t SM9_Sign(
    uint32_t opt,
    const uint8_t *msg,
    uint32_t mlen,
    const uint8_t *ds,
    uint8_t *r,
    const uint8_t *g,
    const uint8_t *mpk,
    uint8_t *sign,
    uint32_t *slen)
{
#ifdef SM9_SIG_USR_ENABLE
    uint32_t BN_h[BNWordLen];
    SM9_ECP_A Ecp_s;
    int32_t ret;
    int32_t len;

#ifndef SM9_ALG_RND_ENABLE
    // r is replaced by a deterministic, NON-random pattern in this configuration.
    uint8_t randBuf[BNByteLen];
    r = randBuf;
    _ibc_alg_rand(r, BNByteLen);
#endif // SM9_ALG_RND_ENABLE

    ret = _sm9_alg_sign(msg, mlen, ds, r, g, mpk, BN_h, &Ecp_s);
    if (ret != CRYPT_SUCCESS) {
        return ret;
    }

    // Output: h first, then S encoded according to opt.
    _ibc_write_fpbytes(sign, BN_h);
    len = SM9_Fp_ECP_A_WriteBytesWithPC(sign + sm9_sys_para.wsize * WordByteLen, opt, &Ecp_s);
    if (len < 0) {
        return CRYPT_SM9_ERR_BAD_INPUT;
    }
    // The length output is optional, as elen/mlen are for encryption/decryption.
    if (slen != NULL) {
        *slen = len + sm9_sys_para.wsize * WordByteLen;
    }

    return CRYPT_SUCCESS;
#else
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#endif /* SM9_SIG_USR_ENABLE */
}

/* Convenience wrapper: SM9_Sign with SM9_OPT_DM_MODE0 encoding of S.
 * The resulting signature length is computed but not reported. */
int32_t SM9_Alg_Sign(const uint8_t *msg, uint32_t mlen,
    const uint8_t *ds, uint8_t *r,
    const uint8_t *g, const uint8_t *mpk,
    uint8_t *sign)
{
    uint32_t ignoredLen = 0;
    int32_t ret = SM9_Sign(SM9_OPT_DM_MODE0, msg, mlen, ds, r, g, mpk, sign, &ignoredLen);

    return ret;
}

#ifdef SM9_SIG_USR_ENABLE
/*
 * Core of SM9 signature verification. BN_h holds the decoded h' (already
 * range-checked by the caller) and *Ecp_s the decoded point S'.
 * With h1 = H1(ID || hid, N) it computes
 *   w' = e(S', [h1]P2 + Ppub-s) * g^h'
 * and accepts iff H2(M' || w', N) == h'.
 * When g (serialised Fp12) is NULL, the same w' is obtained by bilinearity
 * from two pairings that need no G2 scalar multiplication:
 *   w' = e([h']P1 + S', Ppub-s) * e([h1]S', P2)
 * In that branch *Ecp_s is overwritten with [h1]S'.
 * Returns CRYPT_SUCCESS on a match, CRYPT_SM9_VERIFY_FAIL otherwise.
 */
static int32_t _sm9_alg_vefiry(
    const uint8_t *msg, uint32_t mlen,
    const uint8_t *id, uint32_t ilen,
    const uint8_t *g, const uint8_t *mpk,
    uint32_t *BN_h, SM9_ECP_A *Ecp_s)
{
    SM9_Fp12 Fp12_g;
    SM9_Fp12 Fp12_u;
    uint8_t pbW[12 * BNByteLen];

    // BN_h1 holds h1 first and is later reused for h2.
    uint32_t BN_h1[BNWordLen];
    SM9_ECP2_A Ecp2_P;

    SM9_ECP2_J Ecp2_J;

    SM9_ECP_J Ecp_J;
    SM9_ECP_A Ecp_A;

    // Read the signing master public key Ppub-s (affine point, Montgomery form)
    SM9_Ecp2_A_ReadBytes(&Ecp2_P, mpk);

    // h1 = H1(IDA||hid, N)
    SM9_Hash_H1(BN_h1, id, ilen, SM9_HID_Sign, 1);

    // Get System Element g and convert to MontMode
    if (g) {
        // Read from input
        SM9_Fp12_ReadBytes(&Fp12_g, g);
        // t = g^h
        SM9_Fp12_Exp(&Fp12_g, &Fp12_g, BN_h);

        // P = [h1]P2 + Ppub-s  (Ecp2_J keeps Ppub-s; Ecp2_P is reused for [h1]P2, then for P)
        SM9_Ecp2_A_ToJ(&Ecp2_J, &Ecp2_P);
        SM9_Ecp2_KP(&Ecp2_P, &sm9_sys_para.EC_Fp2_G_Mont, BN_h1);
        SM9_Ecp2_J_AddA(&Ecp2_J, &Ecp2_J, &Ecp2_P);
        SM9_Ecp2_J_ToA(&Ecp2_P, &Ecp2_J);
        // u = e(S', P)
        SM9_Alg_Pair_Mont(&Fp12_u, Ecp_s, &Ecp2_P);
    } else {
        // Compute by input
        // t = e([h]P1 + S, Ppub-s)
        SM9_Ecp_KP(&Ecp_A, &sm9_sys_para.EC_Fp_G_Mont, BN_h);
        SM9_Ecp_A_ToJ(&Ecp_J, &Ecp_A);
        SM9_Ecp_J_AddA(&Ecp_J, &Ecp_J, Ecp_s);
        SM9_Ecp_J_ToA(&Ecp_A, &Ecp_J);
        SM9_Alg_Pair_Mont(&Fp12_g, &Ecp_A, &Ecp2_P);

        // u = e([h1]S, P2)  (overwrites the caller's *Ecp_s)
        SM9_Ecp_KP(Ecp_s, Ecp_s, BN_h1);
        SM9_Alg_Pair_Mont(&Fp12_u, Ecp_s, &sm9_sys_para.EC_Fp2_G_Mont);
    }

    // w' = u * t
    SM9_Fp12_Mul(&Fp12_g, &Fp12_g, &Fp12_u);
    SM9_Fp12_WriteBytes(pbW, &Fp12_g);
    // h2 = H2(M'||w', N)
    SM9_Hash_H2(BN_h1, msg, mlen, pbW, 12 * BNByteLen);
    // Verify h2 ?= h
    if (bn_equal(BN_h1, BN_h, BNWordLen))
        return CRYPT_SUCCESS;
    return CRYPT_SM9_VERIFY_FAIL;
}
#endif /* SM9_SIG_USR_ENABLE */

/*
 * Verify an SM9 signature (h, S) over msg for the identity id.
 * opt selects the encoding of S; SM9_OPT_DM_MODE1 signatures are one byte
 * longer than SM9_SIGNATURE_BYTES. g is the optional precomputed
 * e(P1, Ppub-s) and mpk the signing master public key.
 * Returns CRYPT_SUCCESS for a valid signature, CRYPT_SM9_ERR_BAD_INPUT for a
 * wrong length, and CRYPT_SM9_VERIFY_FAIL for any malformed or invalid
 * signature.
 *
 * NOTE(review): slen is uint8_t, so a wider length held by the caller is
 * truncated before this length check; widening it requires changing the
 * public prototype.
 */
int32_t SM9_Verify(
    uint32_t opt,
    const uint8_t *msg,
    uint32_t mlen,
    const uint8_t *id,
    uint32_t ilen,
    const uint8_t *g,
    const uint8_t *mpk,
    const uint8_t *sign,
    uint8_t slen)
{
    uint32_t BN_h[BNWordLen];
    SM9_ECP_A Ecp_s;

    if (opt == SM9_OPT_DM_MODE1) {
        if (slen != SM9_SIGNATURE_BYTES + 1) {
            return CRYPT_SM9_ERR_BAD_INPUT;
        }
    } else if (slen != SM9_SIGNATURE_BYTES) {
        return CRYPT_SM9_ERR_BAD_INPUT;
    }

    // Read h'; it must be in [1, N-1].
    SM9_Bn_ReadBytes(BN_h, sign);
    if (SM9_Bn_IsZero(BN_h) || bn_cmp(BN_h, sm9_sys_para.EC_N, sm9_sys_para.wsize) >= 0) {
        return CRYPT_SM9_VERIFY_FAIL;
    }

    // Decode S' (Montgomery form) and require it to be a curve point before
    // it is used in point arithmetic and pairings.
    sign += sm9_sys_para.wsize * WordByteLen;
    if (SM9_Fp_ECP_A_ReadBytesWithPC(&Ecp_s, opt, sign)) {
        return CRYPT_SM9_VERIFY_FAIL;
    }
    if (SM9_Ecp_A_Check(&Ecp_s)) {
        return CRYPT_SM9_VERIFY_FAIL;
    }

    return _sm9_alg_vefiry(msg, mlen, id, ilen, g, mpk, BN_h, &Ecp_s);
}

/* Convenience wrapper: verify an SM9_OPT_DM_MODE0 signature, which is exactly
 * SM9_SIGNATURE_BYTES long, with SM9_Verify. */
int32_t SM9_Alg_Verify(
    const uint8_t *msg,
    uint32_t mlen,
    const uint8_t *id,
    uint32_t ilen,
    const uint8_t *g,
    const uint8_t *mpk,
    const uint8_t *sign)
{
    int32_t ret = SM9_Verify(SM9_OPT_DM_MODE0, msg, mlen, id, ilen, g, mpk, sign, SM9_SIGNATURE_BYTES);

    return ret;
}

/*
 * Derive the encryption master public key Ppub-e = [ke]P1 from the encryption
 * master private key ke and write it to mpk as a serialised G1 point.
 * ke is used as given; its range is not checked here.
 * The stack copy of ke is wiped before returning.
 * Returns CRYPT_SUCCESS, or CRYPT_SM9_ERR_NOT_SUPPORT when SM9_ENC_SYS_ENABLE
 * is not defined.
 */
int32_t SM9_Alg_MEKG(uint8_t *ke, uint8_t *mpk)
{
#ifdef SM9_ENC_SYS_ENABLE
    uint32_t BN_SK[BNWordLen];
    SM9_ECP_A Point_PK;

    // Read System PriKey to BN
    SM9_Bn_ReadBytes(BN_SK, ke);
    // Generate System PubKey (affine point, Montgomery form)
    SM9_Ecp_KP(&Point_PK, &sm9_sys_para.EC_Fp_G_Mont, BN_SK);
    // Convert System PubKey to bytes (normal form)
    SM9_Ecp_A_WriteBytes(mpk, &Point_PK);
    // The master private key must not linger on the stack.
    BSL_SAL_CleanseData(BN_SK, sizeof(BN_SK));

    return CRYPT_SUCCESS;
#else
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#endif /* SM9_ENC_SYS_ENABLE */
}

/*
 * Derive a user's encryption private key from the encryption master private key:
 *   t1 = H1(ID || hid, N) + ke   (mod N), rejected if zero
 *   t2 = ke * t1^-1              (mod N)
 *   de = [t2]P2
 * id/ilen is the user identity, ke the master private key; de receives the
 * serialised G2 point.
 * Returns CRYPT_SUCCESS, CRYPT_SM9_ERR_KEY_ERR when t1 == 0, or
 * CRYPT_SM9_ERR_NOT_SUPPORT when SM9_ENC_SYS_ENABLE is not defined.
 * All secret intermediates (ke, t1, t2, de) are wiped from the stack on every
 * path.
 */
int32_t SM9_Alg_UEKG(const uint8_t *id, uint32_t ilen, uint8_t *ke, uint8_t *de)
{
#ifdef SM9_ENC_SYS_ENABLE
    int32_t ret = CRYPT_SUCCESS;
    uint32_t BN_t1[BNWordLen];
    uint32_t BN_t2[BNWordLen];
    SM9_ECP2_A Ecp2_ds;

    // Read System PriKey to BN
    SM9_Bn_ReadBytes(BN_t2, ke);

    // Compute t1 = H1(IDB||hid, N) + ke
    SM9_Hash_H1(BN_t1, id, ilen, SM9_HID_Enc, 1);
    bn_mod_add(BN_t1, BN_t1, BN_t2, sm9_sys_para.EC_N, sm9_sys_para.wsize);
    // t1 == 0 has no inverse: this identity cannot be issued a key under ke.
    if (bn_is_zero(BN_t1, sm9_sys_para.wsize)) {
        ret = CRYPT_SM9_ERR_KEY_ERR;
        goto EXIT;
    }
    // Compute t2 = ke * (t1^-1)
    BN_GetInv_Mont(BN_t1, BN_t1, sm9_sys_para.EC_N, sm9_sys_para.N_Mc, sm9_sys_para.N_R2, sm9_sys_para.wsize);
    bn_mont_mul(BN_t2, BN_t1, BN_t2, sm9_sys_para.EC_N, sm9_sys_para.N_Mc, sm9_sys_para.wsize);
    // Compute de = [t2]P2 (multiplies the G2 generator)
    SM9_Ecp2_KP(&Ecp2_ds, &sm9_sys_para.EC_Fp2_G_Mont, BN_t2);

    // Convert User PriKey to bytes (normal form)
    SM9_Ecp2_A_WriteBytes(de, &Ecp2_ds);

EXIT:
    BSL_SAL_CleanseData(BN_t1, sizeof(BN_t1));
    BSL_SAL_CleanseData(BN_t2, sizeof(BN_t2));
    BSL_SAL_CleanseData(&Ecp2_ds, sizeof(Ecp2_ds));
    return ret;
#else
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#endif /* SM9_ENC_SYS_ENABLE */
}

/*
 * Start the encryption KDF: absorb Z = C1 || w || ID into ctx's KDF hash
 * state (SM9_Hash_KDF_Block appends the block counter later) and set
 * ctx->enc.cnt to the 4-byte big-endian counter value 1.
 * C1 is 2*BNByteLen bytes and w is 12*BNByteLen bytes.
 */
void SM9_Hash_KDF_Init(SM9_CTX *ctx, const uint8_t *C1, uint8_t *w, const uint8_t *id, uint32_t ilen)
{
    SM9_Hash_Ctx *kdfState = &ctx->enc.xor_ctx;

    SM9_Hash_Init(kdfState);
    SM9_Hash_Update(kdfState, C1, 2 * BNByteLen);
    SM9_Hash_Update(kdfState, w, 12 * BNByteLen);
    SM9_Hash_Update(kdfState, id, ilen);

    ctx->enc.cnt[0] = 0;
    ctx->enc.cnt[1] = 0;
    ctx->enc.cnt[2] = 0;
    ctx->enc.cnt[3] = 1;
}

/*
 * Produce KDF output block number `cnt`:
 *   k = SM3(Z || cnt)   with cnt encoded as 4 big-endian bytes,
 * where Z was absorbed by SM9_Hash_KDF_Init. ctx is only read, so blocks can
 * be produced in any order. k receives 32 bytes.
 * The temporary hash state (which has absorbed the secret w via Z) is wiped
 * afterwards.
 */
void SM9_Hash_KDF_Block(SM9_CTX *ctx, uint32_t cnt, uint8_t *k)
{
    SM9_Hash_Ctx tmp_ctx;
    uint8_t c[4];

    c[0] = (uint8_t)(cnt >> 24);
    c[1] = (uint8_t)(cnt >> 16);
    c[2] = (uint8_t)(cnt >> 8);
    c[3] = (uint8_t)cnt;
    memcpy(&tmp_ctx, &ctx->enc.xor_ctx, sizeof(SM9_Hash_Ctx));

    SM9_Hash_Update(&tmp_ctx, c, 4);
    SM9_Hash_Final(&tmp_ctx, k);
    BSL_SAL_CleanseData(&tmp_ctx, sizeof(tmp_ctx));
}

/*
 * First half of SM9 encryption:
 *   r  = caller's random value mapped by reduce / +1 / reduce
 *   w  = g^r, or e([r]Ppub-e, P2) when g is NULL
 *   QB = [H1(ID || hid, N)]P1 + Ppub-e
 *   C1 = [r]QB                      -> written to C1 (2*BNByteLen bytes)
 * Then the KDF over C1 || w || ID is started and the byte counter reset.
 * r, w and [r]Ppub-e (which determines w) are wiped from the stack before
 * returning.
 * NOTE(review): if the mapping yields r == 0 (possible only if SM9_Fn_LastRes
 * maps N to 0), C1 is the point at infinity; this void helper cannot report it.
 */
static void _sm9_enc_init(SM9_CTX *ctx, const uint8_t *id, uint32_t ilen, uint8_t *r,
    const uint8_t *g, const uint8_t *mpk, uint8_t *C1)
{
    uint32_t BN_r[BNWordLen];
    uint32_t BN_h[BNWordLen];
    SM9_ECP_A Ecp_P;
    SM9_ECP_A Ecp_T;
    SM9_Fp12 Fp12_g;
    uint8_t pbW[12 * BNByteLen];

    // Read random number r (normal form) and fold it into range.
    SM9_Bn_ReadBytes(BN_r, r);
    SM9_Fn_LastRes(BN_r);
    bn_add_int(BN_r, BN_r, 1, BNWordLen);
    SM9_Fn_LastRes(BN_r);
    // Read the encryption master public key Ppub-e (affine, Montgomery form).
    SM9_Ecp_A_ReadBytes(&Ecp_P, mpk);

    if (g != NULL) {
        // w = g^r
        SM9_Fp12_ReadBytes(&Fp12_g, g);
        SM9_Fp12_Exp(&Fp12_g, &Fp12_g, BN_r);
    } else {
        // w = e([r]Ppub-e, P2)
        SM9_Ecp_KP(&Ecp_T, &Ecp_P, BN_r);
        SM9_Alg_Pair_Mont(&Fp12_g, &Ecp_T, &sm9_sys_para.EC_Fp2_G_Mont);
    }
    SM9_Fp12_WriteBytes(pbW, &Fp12_g);

    // h1 = H1(IDB||hid, N)
    SM9_Hash_H1(BN_h, id, ilen, SM9_HID_Enc, 1);

    // QB = [h1]P1 + Ppub-e
    SM9_Fp_ECP_KPAddAToA(&Ecp_P, &sm9_sys_para.EC_Fp_G_Mont, BN_h, &Ecp_P, &sm9_sys_para);

    // C1 = [r]QB
    SM9_Ecp_KP(&Ecp_P, &Ecp_P, BN_r);
    SM9_Ecp_A_WriteBytes(C1, &Ecp_P);

    SM9_Hash_KDF_Init(ctx, C1, pbW, id, ilen);
    ctx->enc.bytes = 0;

    BSL_SAL_CleanseData(BN_r, sizeof(BN_r));
    BSL_SAL_CleanseData(&Ecp_T, sizeof(Ecp_T));
    BSL_SAL_CleanseData(&Fp12_g, sizeof(Fp12_g));
    BSL_SAL_CleanseData(pbW, sizeof(pbW));
}

// Key-stream helper for SM9 public-key encryption.
// The key stream is the concatenation of KDF blocks cnt = 1, 2, ... (each
// SM9_Hash_Size bytes) from SM9_Hash_KDF_Block; ctx->enc.bytes counts the
// stream bytes consumed so far by Mode-1 calls.
// Mode-1: msg is not NULL: enc = msg XOR the next mlen stream bytes (starting
//         at offset ctx->enc.bytes), then ctx->enc.bytes advances by mlen.
//         This can resume in the middle of a block across calls.
// Mode-2: msg is NULL and mlen is 0: copy SM9_Hash_Size stream bytes starting
//         at offset ctx->enc.bytes into enc (ctx->enc.bytes is unchanged).
// Mode-3: msg is NULL and mlen is not 0: copy SM9_Hash_Size stream bytes
//         starting at offset mlen into enc (ctx->enc.bytes is unchanged).
// The local key-stream buffer is wiped before returning.
static void _sm9_pke_kdf(SM9_CTX *ctx, const uint8_t *msg, uint32_t mlen, uint8_t *enc)
{
    uint8_t k[2 * SM9_Hash_Size];
    uint32_t i;
    uint32_t cnt;
    uint32_t res;

    cnt = ctx->enc.bytes / SM9_Hash_Size + 1;
    res = ctx->enc.bytes % SM9_Hash_Size;

    if (msg == NULL) {
        if (mlen != 0) {
            cnt = mlen / SM9_Hash_Size + 1;
            res = mlen % SM9_Hash_Size;
        }
        // An unaligned offset spans two blocks.
        SM9_Hash_KDF_Block(ctx, cnt++, k);
        if (res != 0) {
            SM9_Hash_KDF_Block(ctx, cnt++, k + SM9_Hash_Size);
        }
        memcpy(enc, k + res, SM9_Hash_Size);
        BSL_SAL_CleanseData(k, sizeof(k));
        return;
    }

    // Resume a partially consumed block: its first `res` bytes were used by an
    // earlier call, so the next stream byte is k[res] and the output starts at
    // enc[0]. Never consume more than mlen bytes.
    if (res != 0 && mlen != 0) {
        uint32_t take = SM9_Hash_Size - res;
        if (take > mlen) {
            take = mlen;
        }
        SM9_Hash_KDF_Block(ctx, cnt++, k);
        SM9_Xor(enc, take, msg, k + res);
        enc += take;
        msg += take;
        mlen -= take;
        ctx->enc.bytes += take;
    }

    // Whole blocks.
    for (i = 0; i < mlen / SM9_Hash_Size; i++) {
        SM9_Hash_KDF_Block(ctx, cnt++, k);
        SM9_Xor(enc, SM9_Hash_Size, msg, k);
        enc += SM9_Hash_Size;
        msg += SM9_Hash_Size;
    }
    // Trailing partial block; a later Mode-1 call resumes inside it.
    res = mlen % SM9_Hash_Size;
    if (res != 0) {
        SM9_Hash_KDF_Block(ctx, cnt, k);
        SM9_Xor(enc, res, msg, k);
    }
    ctx->enc.bytes += mlen;
    BSL_SAL_CleanseData(k, sizeof(k));
}

/* Start the encryption MAC, MAC(K2, Z) = SM3(Z || K2), by resetting ctx's MAC hash. */
void SM9_Mac_Init(SM9_CTX *ctx)
{
    SM9_Hash_Ctx *macState = &ctx->mac_ctx;

    SM9_Hash_Init(macState);
}

/* Feed mlen bytes of the MAC input Z (the ciphertext part C2) into ctx's MAC hash. */
void SM9_Mac_Update(SM9_CTX *ctx, const uint8_t *msg, uint32_t mlen)
{
    SM9_Hash_Ctx *macState = &ctx->mac_ctx;

    SM9_Hash_Update(macState, msg, mlen);
}

/* Append the MAC key K2 (klen bytes) and write SM3(Z || K2) to `mac`. */
void SM9_Mac_Final(SM9_CTX *ctx, uint8_t *key, uint32_t klen, uint8_t *mac)
{
    SM9_Hash_Ctx *macState = &ctx->mac_ctx;

    SM9_Hash_Update(macState, key, klen);
    SM9_Hash_Final(macState, mac);
}

/*
 * SM9 public-key encryption. Output layout is C1 || C3 || C2:
 *   C1 = [r]QB        (SM9_C1_ByteLen bytes)
 *   C3 = MAC(K2, C2)  (SM9_C3_ByteLen bytes)
 *   C2 = msg XOR K1   (mlen bytes)
 * where K1 || K2 is the KDF output over C1 || w || ID.
 *
 * msg/mlen: plaintext; msg may be NULL only when mlen is 0.
 * id/ilen:  recipient identity.
 * r:        caller's random value (replaced by a deterministic pattern when
 *           SM9_ALG_RND_ENABLE is not defined).
 * g:        optional precomputed e(Ppub-e, P2); mpk: encryption master public key.
 * enc:      output of at least mlen + SM9_C1_ByteLen + SM9_C3_ByteLen bytes.
 *           It must not overlap msg: C1 is written before msg is read, and
 *           the plaintext is compared with C2 afterwards.
 * elen:     optional; receives the ciphertext length.
 *
 * Returns CRYPT_SUCCESS; CRYPT_SM9_ERR_BAD_INPUT for a NULL msg with non-zero
 * mlen or overlapping buffers; CRYPT_SM9_ERR_ENCRYPT_FAILED if K1 is all zero
 * (retry with a fresh r); CRYPT_SM9_ERR_NOT_SUPPORT when SM9_ENC_USR_ENABLE is
 * not defined.
 */
int32_t SM9_Alg_Enc(const uint8_t *msg, uint32_t mlen,
                    const uint8_t *id, uint32_t ilen, uint8_t *r,
                    const uint8_t *g, const uint8_t *mpk,
                    uint8_t *enc, uint32_t *elen)
{
#ifdef SM9_ENC_USR_ENABLE
    int32_t ret = CRYPT_SUCCESS;
    SM9_CTX sm9_ctx;
    uint8_t *C1;
    uint8_t *C2;
    uint8_t *C3;
    uint8_t mkey[SM9_Hash_Size];

#ifndef SM9_ALG_RND_ENABLE
    // r is replaced by a deterministic, NON-random pattern in this configuration.
    uint8_t randBuf[BNByteLen];
    r = randBuf;
    _ibc_alg_rand(r, BNByteLen);
#endif // SM9_ALG_RND_ENABLE

    // A NULL msg with mlen == 0 is valid; a NULL msg with data is not.
    if (msg == NULL && mlen != 0) {
        return CRYPT_SM9_ERR_BAD_INPUT;
    }

    C1 = enc;
    C3 = C1 + SM9_C1_ByteLen;
    C2 = C3 + SM9_C3_ByteLen;

    // Reject any overlap between msg and the whole output region C1..C2+mlen.
    if (mlen != 0) {
        uintptr_t msgStart = (uintptr_t)msg;
        uintptr_t msgEnd = msgStart + mlen;
        uintptr_t outStart = (uintptr_t)enc;
        uintptr_t outEnd = (uintptr_t)(C2 + mlen);
        if (msgStart < outEnd && outStart < msgEnd) {
            return CRYPT_SM9_ERR_BAD_INPUT;
        }
    }

    _sm9_enc_init(&sm9_ctx, id, ilen, r, g, mpk, C1);
    SM9_Mac_Init(&sm9_ctx);

    // With mlen == 0 there is nothing to encrypt. Calling the KDF helper with a
    // NULL msg would switch it to "copy 32 key bytes" mode and write past C2.
    if (mlen != 0) {
        _sm9_pke_kdf(&sm9_ctx, msg, mlen, C2);
        SM9_Mac_Update(&sm9_ctx, C2, mlen);

        // K1 is all zeros if C2 == msg after XOR, must reject per GM/T 0044-2016 A6 a.1
        // ConstTimeMemcmp returns non-zero when equal
        if (ConstTimeMemcmp(C2, msg, mlen) != 0) {
            // C2 now holds the plaintext; do not leave it in the output buffer.
            BSL_SAL_CleanseData(C2, mlen);
            ret = CRYPT_SM9_ERR_ENCRYPT_FAILED;
            goto EXIT;
        }
    }

    // K2 = the SM9_Hash_Size key-stream bytes following K1; C3 = MAC(K2, C2).
    _sm9_pke_kdf(&sm9_ctx, NULL, 0, mkey);
    SM9_Mac_Final(&sm9_ctx, mkey, SM9_Hash_Size, C3);

    if (elen != NULL) {
        *elen = mlen + SM9_C1_ByteLen + SM9_C3_ByteLen;
    }

EXIT:
    BSL_SAL_CleanseData(mkey, SM9_Hash_Size);
    BSL_SAL_CleanseData(&sm9_ctx, sizeof(sm9_ctx));
    return ret;
#else
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#endif /* SM9_ENC_USR_ENABLE */
}

/*
 * Start SM9 decryption:
 *   - read the user's encryption private key de and the ciphertext part C1,
 *   - reject C1 if SM9_Ecp_A_Check flags it,
 *   - compute w' = e(C1, de),
 *   - start the KDF over C1 || w' || ID and reset the byte counter.
 * Returns CRYPT_SUCCESS, or CRYPT_SM9_ERR_DECRYPT_FAILED when C1 is invalid.
 * The private key and w' are wiped from the stack on every path.
 */
int32_t SM9_Dec_Init(SM9_CTX *ctx, const uint8_t *de, const uint8_t *id, uint32_t ilen, const uint8_t *C1)
{
    int32_t ret = CRYPT_SUCCESS;
    SM9_ECP_A Ecp_P;
    SM9_ECP2_A Ecp2_D;
    SM9_Fp12 Fp12_g;
    uint8_t pbW[12 * BNByteLen];

    // Read User Prikey and convert to MontMode
    SM9_Ecp2_A_ReadBytes(&Ecp2_D, de);
    // Read Cipher Part1 (affine point, Montgomery form)
    SM9_Ecp_A_ReadBytes(&Ecp_P, C1);

    // Check C1 is a point
    if (SM9_Ecp_A_Check(&Ecp_P)) {
        ret = CRYPT_SM9_ERR_DECRYPT_FAILED;
        goto EXIT;
    }

    // w' = e(C1, de)
    SM9_Alg_Pair_Mont(&Fp12_g, &Ecp_P, &Ecp2_D);
    SM9_Fp12_WriteBytes(pbW, &Fp12_g);

    SM9_Hash_KDF_Init(ctx, C1, pbW, id, ilen);
    ctx->enc.bytes = 0;

EXIT:
    BSL_SAL_CleanseData(&Ecp2_D, sizeof(Ecp2_D));
    BSL_SAL_CleanseData(&Fp12_g, sizeof(Fp12_g));
    BSL_SAL_CleanseData(pbW, sizeof(pbW));
    return ret;
}

/*
 * SM9 public-key decryption of a C1 || C3 || C2 ciphertext of elen bytes.
 * de is the recipient's encryption private key and id/ilen the identity.
 * The MAC C3 is verified before any plaintext is produced; msg then receives
 * elen - SM9_C1_ByteLen - SM9_C3_ByteLen bytes (msg may equal C2 for
 * in-place decryption). mlen is optional and receives that length.
 * Returns CRYPT_SUCCESS; CRYPT_SM9_ERR_BAD_INPUT if elen is too short;
 * CRYPT_SM9_ERR_DECRYPT_FAILED for an invalid C1 or a MAC mismatch;
 * CRYPT_SM9_ERR_NOT_SUPPORT when SM9_ENC_USR_ENABLE is not defined.
 * Key material and the KDF/MAC state are wiped on every path.
 */
int32_t SM9_Alg_Dec(const uint8_t *enc, uint32_t elen,
                    const uint8_t *de, const uint8_t *id, uint32_t ilen,
                    uint8_t *msg, uint32_t *mlen)
{
#ifdef SM9_ENC_USR_ENABLE
    SM9_CTX sm9_ctx;
    const uint8_t *C1;
    const uint8_t *C2;
    const uint8_t *C3;
    uint8_t k[2 * SM9_Hash_Size];
    uint8_t mac[SM9_C3_ByteLen];
    uint32_t len;
    int32_t ret;

    if (elen < SM9_C1_ByteLen + SM9_C3_ByteLen) {
        return CRYPT_SM9_ERR_BAD_INPUT;
    }
    len = elen - SM9_C1_ByteLen - SM9_C3_ByteLen;

    C1 = enc;
    C3 = C1 + SM9_C1_ByteLen;
    C2 = C3 + SM9_C3_ByteLen;

    ret = SM9_Dec_Init(&sm9_ctx, de, id, ilen, C1);
    if (ret != CRYPT_SUCCESS) {
        goto EXIT;
    }

    // Compute MAC(K2', C2), with K2' the key-stream bytes following the
    // len bytes of K1', and compare it to C3.
    SM9_Mac_Init(&sm9_ctx);
    SM9_Mac_Update(&sm9_ctx, C2, len);
    _sm9_pke_kdf(&sm9_ctx, NULL, len, k);
    SM9_Mac_Final(&sm9_ctx, k, SM9_Hash_Size, mac);
    // ConstTimeMemcmp returns zero when the buffers differ.
    if (ConstTimeMemcmp(mac, C3, SM9_C3_ByteLen) == 0) {
        ret = CRYPT_SM9_ERR_DECRYPT_FAILED;
        goto EXIT;
    }

    // M' = C2 XOR K1'
    _sm9_pke_kdf(&sm9_ctx, C2, len, msg);

    if (mlen != NULL) {
        *mlen = len;
    }
    ret = CRYPT_SUCCESS;

EXIT:
    BSL_SAL_CleanseData(k, sizeof(k));
    BSL_SAL_CleanseData(&sm9_ctx, sizeof(sm9_ctx));
    return ret;
#else
    return CRYPT_SM9_ERR_NOT_SUPPORT;
#endif /* SM9_ENC_USR_ENABLE */
}

#endif // HITLS_CRYPTO_SM9