* This file is part of the openHiTLS project.
*
* openHiTLS is licensed under the Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef CRYPT_ML_KEM_LOCAL_H
#define CRYPT_ML_KEM_LOCAL_H
#include "crypt_mlkem.h"
#include "sal_atomic.h"
#include "crypt_local_types.h"
#define BITS_OF_BYTE 8
#define MLKEM_ETA1_MAX 3
#define MLKEM_ETA2_MAX 2
#define MLKEM_N 256
#define MLKEM_N_HALF 128
#define MLKEM_CIPHER_LEN 384
#define MLKEM_SEED_LEN 32
#define MLKEM_SHARED_KEY_LEN 32
#define MLKEM_PRF_BLOCKSIZE 64
#define MLKEM_ENCODE_BLOCKSIZE 32
#define MLKEM_Q 3329
#define MLKEM_Q_INV_BETA (-3327)
#define MLKEM_Q_HALF ((MLKEM_Q + 1) / 2)
#define MLKEM_BITS_OF_Q 12
#define MLKEM_INVN 3303
#define MLKEM_K_MAX 4
#define MLKEM_PLANTARD_L 16
#define MLKEM_PLANTARD_ALPHA 3
#define MLKEM_LAST_ROUND_ZETA 2131356556
#define MLKEM_HALF_DEGREE_INVERSE_MOD_Q (-33544352)
typedef int32_t (*MlKemHashFunc)(uint32_t id, const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t *outLen);
static inline int16_t BarrettReduction(int32_t a)
{
const int32_t v = ((1 << 27) + MLKEM_Q / 2) / MLKEM_Q;
int32_t t = ((int64_t)v * a + (1 << 26)) >> 27;
t *= MLKEM_Q;
return (int16_t)(a - t);
}
static inline int16_t PlantardReduction(int32_t a)
{
int32_t tmp = a;
tmp >>= MLKEM_PLANTARD_L;
tmp = (tmp + (1 << MLKEM_PLANTARD_ALPHA)) * MLKEM_Q;
tmp >>= MLKEM_PLANTARD_L;
return tmp;
}
typedef struct {
int16_t *bufAddr;
int16_t *matrix[MLKEM_K_MAX][MLKEM_K_MAX];
int16_t *vectorS[MLKEM_K_MAX];
int16_t *vectorE[MLKEM_K_MAX];
int16_t *vectorT[MLKEM_K_MAX];
} MLKEM_MatrixSt;
typedef struct {
int32_t paramId;
uint8_t k;
uint8_t eta1;
uint8_t eta2;
uint8_t du;
uint8_t dv;
uint32_t secBits;
uint32_t encapsKeyLen;
uint32_t decapsKeyLen;
uint32_t cipherLen;
uint32_t sharedLen;
uint32_t bits;
} CRYPT_MlKemInfo;
struct CryptMlKemCtx {
int32_t algId;
const CRYPT_MlKemInfo *info;
uint8_t *ek;
uint32_t ekLen;
uint8_t *dk;
uint32_t dkLen;
BSL_SAL_RefCount references;
void *libCtx;
MLKEM_MatrixSt keyData;
CRYPT_ALGO_MLKEM_DK_FORMAT_TYPE dkFormat;
bool hasSeed;
uint8_t seed[MLKEM_SEED_LEN * 2];
};
int32_t MLKEM_DecodeDk(CRYPT_ML_KEM_Ctx *ctx, const uint8_t *dk, uint32_t dkLen);
int32_t MLKEM_DecodeEk(CRYPT_ML_KEM_Ctx *ctx, const uint8_t *ek, uint32_t ekLen);
void MLKEM_ComputNTT(int16_t *a, const int32_t *psi);
void MLKEM_ComputINTT(int16_t *a, const int32_t *psi);
void MLKEM_SamplePolyCBD(int16_t *polyF, uint8_t *buf, uint8_t eta);
void MLKEM_TransposeMatrixMulAdd(uint8_t k, int16_t **matrix, int16_t **polyVec, int16_t **polyVecOut,
const int16_t mulCache[MLKEM_K_MAX][MLKEM_N_HALF]);
void MLKEM_MatrixMulAdd(uint8_t k, int16_t **matrix, int16_t **polyVec, int16_t **polyVecOut,
const int16_t mulCache[MLKEM_K_MAX][MLKEM_N_HALF]);
void MLKEM_VectorInnerProductAdd(uint8_t k, int16_t **polyVec1, int16_t **polyVec2, int16_t *polyOut,
const int32_t *factor);
void MLKEM_VectorInnerProductAddUseCache(uint8_t k, int16_t **polyVec1, int16_t **polyVec2, int16_t *polyOut,
const int16_t mulCache[MLKEM_K_MAX][MLKEM_N_HALF]);
void MLKEM_ComputeMulCache(uint8_t k, int16_t **input, int16_t output[MLKEM_K_MAX][MLKEM_N_HALF],
const int32_t *factor);
int32_t MLKEM_KeyGenInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *d, uint8_t *z);
int32_t MLKEM_EncapsInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t *ctLen, uint8_t *sk, uint32_t *skLen,
uint8_t *m);
int32_t MLKEM_DecapsInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t ctLen, uint8_t *sk, uint32_t *skLen);
int32_t MLKEM_CreateMatrixBuf(uint8_t k, MLKEM_MatrixSt *st);
int32_t MLKEM_PKEGen(CRYPT_ML_KEM_Ctx *ctx, uint8_t *digest, uint8_t *pk, uint8_t *dk);
int32_t MLKEM_PKEEnc(uint32_t k, MLKEM_MatrixSt *mat, uint8_t du, uint8_t dv, uint8_t *ct,
int16_t *y[], int16_t *e1[], int16_t *u[],
int16_t *e2, int16_t *mu, int16_t *c2);
int32_t MLKEM_PKEDec(uint32_t k, MLKEM_MatrixSt *mat, int16_t *m, int16_t *c1[], int16_t *c2, uint8_t *result);
void ByteEncode(uint8_t *out, int16_t *in, uint8_t bits);
int32_t GenMatrix(const CRYPT_ML_KEM_Ctx *ctx, const uint8_t *seed,
int16_t *polyMatrix[MLKEM_K_MAX][MLKEM_K_MAX], bool isEnc);
int32_t SampleEta2(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *seed, int16_t *s[], int16_t *e[]);
#endif