* This file is part of the openHiTLS project.
*
* openHiTLS is licensed under the Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "hitls_build.h"
#ifdef HITLS_CRYPTO_MLKEM
#include <string.h>
#include "crypt_utils.h"
#include "crypt_sha3.h"
#include "crypt_errno.h"
#include "asm_sha3.h"
#include "ml_kem_local.h"
#define GEN_MATRIX_NBLOCKS \
((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + CRYPT_SHAKE128_BLOCKSIZE) / CRYPT_SHAKE128_BLOCKSIZE)
void MLKEMPointMulExtended(int16_t *, int16_t *, const int16_t *);
void MLKEMAsymMulMont(uint32_t k, int16_t *r, int16_t *a, int16_t *b, int16_t *c);
void MLKEMAsymMul(uint32_t k, int16_t *r, int16_t **a, int16_t *b, int16_t *c);
void MLKEMPolyAddReduce(uint32_t k, int16_t **r, int16_t **e);
void MLKEMAdd2Reduce(int16_t *, int16_t *, const int16_t *);
void MLKEMSubReduce(int16_t *, int16_t *);
void MLKEMPolyVecReduce(uint32_t k, int16_t **r);
void PolyVectoBytes(uint32_t k, uint8_t *r, int16_t **a);
void MLKEMPolyINTTtoMont(uint32_t k, int16_t **r);
void MLKEMPolyCBDEta(uint32_t eta, int16_t *r, const uint8_t *buf);
void PolyVecCompress(uint32_t k, uint8_t *r, int16_t **a);
void PolyCompress(uint32_t k, uint8_t *r, int16_t *a);
void PackCipherText(uint32_t k, uint32_t c1len, uint8_t *ct, int16_t **t, int16_t *c2);
void MLKEMPolyCompress4(uint8_t *, int16_t *);
void MLKEMPolyCompress5(uint8_t *, int16_t *);
void MLKEMPolyCompress1(uint8_t *, int16_t *);
unsigned int MLKEMRejUniform(int16_t *, const uint8_t *);
void MLKEM_ComputeNTTAsm(int16_t *a);
void MLKEM_ComputeINTTAsm(int16_t *a);
* Base multiplication (basemul) twiddle factor table.
*
* X^256 + 1 decomposes into 64 degree-4 factors, each splitting into two
* degree-2 factors:
* (X^2 - zeta_j) * (X^2 + zeta_j), zeta_j = omega^(2*bitrev(j,6)+1) mod^{+-} Q
* for j = 0..63. The basemul needs both zeta_j and its Barrett companion zeta_hi_j.
*
* Generation script (verified against this table):
*
* # Uses Q, omega, bitrev, mod_pm, barrett_hi defined above.
* def gen_basemul_table(): # -> 256 entries (MLKEM_N)
* T = []
* for row in range(16): # 16 rows of 16 entries
* roots = []
* for m in range(4): # 4 twiddles per row
* j = 4 * row + m
* k = 2 * bitrev(j, 6) + 1
* z = mod_pm(pow(omega, k, Q))
* roots.append((z, barrett_hi(z)))
* for z, _ in roots: T += [z, -z] # root part
* for _, h in roots: T += [h, -h] # root_hi part
* return T
*/
const __attribute__((aligned(16))) int16_t MLKEM_BASEMUL_TWIDDLE_TABLE[MLKEM_N] = {
17, -17, -568, 568, 583, -583, -680, 680, 167, -167, -5591, 5591, 5739, -5739,
-6693, 6693, 1637, -1637, 723, -723, -1041, 1041, 1100, -1100, 16113, -16113, 7117, -7117,
-10247, 10247, 10828, -10828, 1409, -1409, -667, 667, -48, 48, 233, -233, 13869, -13869,
-6565, 6565, -472, 472, 2293, -2293, 756, -756, -1173, 1173, -314, 314, -279, 279,
7441, -7441, -11546, 11546, -3091, 3091, -2746, 2746, -1626, 1626, 1651, -1651, -540, 540,
-1540, 1540, -16005, 16005, 16251, -16251, -5315, 5315, -15159, 15159, -1482, 1482, 952, -952,
1461, -1461, -642, 642, -14588, 14588, 9371, -9371, 14381, -14381, -6319, 6319, 939, -939,
-1021, 1021, -892, 892, -941, 941, 9243, -9243, -10050, 10050, -8780, 8780, -9262, 9262,
733, -733, -992, 992, 268, -268, 641, -641, 7215, -7215, -9764, 9764, 2638, -2638,
6309, -6309, 1584, -1584, -1031, 1031, -1292, 1292, -109, 109, 15592, -15592, -10148, 10148,
-12717, 12717, -1073, 1073, 375, -375, -780, 780, -1239, 1239, 1645, -1645, 3691, -3691,
-7678, 7678, -12196, 12196, 16192, -16192, 1063, -1063, 319, -319, -556, 556, 757, -757,
10463, -10463, 3140, -3140, -5473, 5473, 7451, -7451, -1230, 1230, 561, -561, -863, 863,
-735, 735, -12107, 12107, 5522, -5522, -8495, 8495, -7235, 7235, -525, 525, 1092, -1092,
403, -403, 1026, -1026, -5168, 5168, 10749, -10749, 3967, -3967, 10099, -10099, 1143, -1143,
-1179, 1179, -554, 554, 886, -886, 11251, -11251, -11605, 11605, -5453, 5453, 8721, -8721,
-1607, 1607, 1212, -1212, -1455, 1455, 1029, -1029, -15818, 15818, 11930, -11930, -14322, 14322,
10129, -10129, -1219, 1219, -394, 394, 885, -885, -1175, 1175, -11999, 11999, -3878, 3878,
8711, -8711, -11566, 11566};
void KyberShake256x2Prf(uint8_t *out1, uint8_t *out2, size_t outlen, const uint8_t key[MLKEM_SEED_LEN], uint8_t nonce1,
uint8_t nonce2)
{
uint8_t extkey1[MLKEM_SEED_LEN + 1];
uint8_t extkey2[MLKEM_SEED_LEN + 1];
memcpy(extkey1, key, MLKEM_SEED_LEN);
memcpy(extkey2, key, MLKEM_SEED_LEN);
extkey1[MLKEM_SEED_LEN] = nonce1;
extkey2[MLKEM_SEED_LEN] = nonce2;
Shake256x2(out1, out2, outlen, extkey1, extkey2, sizeof(extkey1));
BSL_SAL_CleanseData(extkey1, MLKEM_SEED_LEN + 1);
BSL_SAL_CleanseData(extkey2, MLKEM_SEED_LEN + 1);
}
void KyberShake128x2Absorb(Keccakx2State state, const uint8_t seed[MLKEM_SEED_LEN], uint8_t x1, uint8_t x2, uint8_t y1,
uint8_t y2)
{
uint8_t extseed1[MLKEM_SEED_LEN + 2 + 6];
uint8_t extseed2[MLKEM_SEED_LEN + 2 + 6];
memcpy(extseed1, seed, MLKEM_SEED_LEN);
memcpy(extseed2, seed, MLKEM_SEED_LEN);
extseed1[MLKEM_SEED_LEN] = x1;
extseed1[MLKEM_SEED_LEN + 1] = y1;
extseed2[MLKEM_SEED_LEN] = x2;
extseed2[MLKEM_SEED_LEN + 1] = y2;
Keccakx2Absorb(state, CRYPT_SHAKE128_BLOCKSIZE, extseed1, extseed2, MLKEM_SEED_LEN + 2, 0x1F);
}
void KyberShakeAbsorb(ShakeState *state, const uint8_t seed[MLKEM_SEED_LEN], uint8_t x, uint8_t y)
{
uint8_t extseed[MLKEM_SEED_LEN + 2];
memcpy(extseed, seed, MLKEM_SEED_LEN);
extseed[MLKEM_SEED_LEN] = x;
extseed[MLKEM_SEED_LEN + 1] = y;
KeccakAbsorb(state->s, CRYPT_SHAKE128_BLOCKSIZE, extseed, sizeof(extseed), 0x1F);
state->pos = CRYPT_SHAKE128_BLOCKSIZE;
}
void PolyGetNoiseEtaX2(uint32_t eta, int16_t vec1[MLKEM_N], int16_t vec2[MLKEM_N], const uint8_t seed[MLKEM_SEED_LEN],
uint8_t nonce1, uint8_t nonce2)
{
uint8_t buf1[eta * MLKEM_N / 4], buf2[eta * MLKEM_N / 4];
KyberShake256x2Prf(buf1, buf2, sizeof(buf1), seed, nonce1, nonce2);
MLKEMPolyCBDEta(eta, vec1, buf1);
MLKEMPolyCBDEta(eta, vec2, buf2);
}
int32_t SampleEta1(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *seed, int16_t *s[], int16_t *e[])
{
uint32_t k = ctx->info->k;
uint32_t eta1 = ctx->info->eta1;
if (k == 2) {
PolyGetNoiseEtaX2(eta1, s[0], s[1], seed, 0, 1);
PolyGetNoiseEtaX2(eta1, e[0], e[1], seed, 2, 3);
} else if (k == 3) {
PolyGetNoiseEtaX2(eta1, s[0], s[1], seed, 0, 1);
PolyGetNoiseEtaX2(eta1, s[2], e[0], seed, 2, 3);
PolyGetNoiseEtaX2(eta1, e[1], e[2], seed, 4, 5);
} else if (k == 4) {
PolyGetNoiseEtaX2(eta1, s[0], s[1], seed, 0, 1);
PolyGetNoiseEtaX2(eta1, s[2], s[3], seed, 2, 3);
PolyGetNoiseEtaX2(eta1, e[0], e[1], seed, 4, 5);
PolyGetNoiseEtaX2(eta1, e[2], e[3], seed, 6, 7);
}
for (uint32_t i = 0; i < k; i++) {
MLKEM_ComputeNTTAsm(s[i]);
MLKEM_ComputeNTTAsm(e[i]);
}
return CRYPT_SUCCESS;
}
int32_t SampleEta2(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *seed, int16_t *s[], int16_t *e[])
{
uint32_t k = ctx->info->k;
uint32_t eta1 = ctx->info->eta1;
uint32_t eta2 = ctx->info->eta2;
if (k == 2) {
PolyGetNoiseEtaX2(eta1, s[0], s[1], seed, 0, 1);
PolyGetNoiseEtaX2(eta2, e[0], e[1], seed, 2, 3);
seed[MLKEM_SEED_LEN] = 4;
} else if (k == 3) {
PolyGetNoiseEtaX2(eta1, s[0], s[1], seed, 0, 1);
PolyGetNoiseEtaX2(eta1, s[2], e[0], seed, 2, 3);
PolyGetNoiseEtaX2(eta1, e[1], e[2], seed, 4, 5);
seed[MLKEM_SEED_LEN] = 6;
} else if (k == 4) {
PolyGetNoiseEtaX2(eta1, s[0], s[1], seed, 0, 1);
PolyGetNoiseEtaX2(eta1, s[2], s[3], seed, 2, 3);
PolyGetNoiseEtaX2(eta1, e[0], e[1], seed, 4, 5);
PolyGetNoiseEtaX2(eta1, e[2], e[3], seed, 6, 7);
seed[MLKEM_SEED_LEN] = 8;
}
for (uint32_t i = 0; i < k; i++) {
MLKEM_ComputeNTTAsm(s[i]);
}
return CRYPT_SUCCESS;
}
static unsigned int Parse(int16_t *poly, uint8_t *arrayB, uint32_t arrayLen, uint32_t n)
{
uint32_t i = 0;
uint32_t j = 0;
while (j < n) {
if (i + 3 > arrayLen) {
return j;
}
uint16_t d1 = ((uint16_t)arrayB[i]) + (((uint16_t)(arrayB[i + 1] & 0x0F)) << 8);
uint16_t d2 = (((uint16_t)arrayB[i + 1]) >> 4) + (((uint16_t)arrayB[i + 2]) << 4);
i += 3;
if (d1 < MLKEM_Q) {
poly[j++] = (int16_t)d1;
}
if (j < n && d2 < MLKEM_Q) {
poly[j++] = (int16_t)d2;
}
}
return j;
}
static void GenMatrixX2Pair(const uint8_t *seed, int16_t *poly0, int16_t *poly1,
uint8_t x1, uint8_t x2, uint8_t y1, uint8_t y2)
{
uint8_t buf0[GEN_MATRIX_NBLOCKS * CRYPT_SHAKE128_BLOCKSIZE + 2];
uint8_t buf1[GEN_MATRIX_NBLOCKS * CRYPT_SHAKE128_BLOCKSIZE + 2];
Keccakx2State state;
KyberShake128x2Absorb(state, seed, x1, x2, y1, y2);
Keccakx2Squeeze(buf0, buf1, GEN_MATRIX_NBLOCKS, CRYPT_SHAKE128_BLOCKSIZE, state);
uint32_t buflen = GEN_MATRIX_NBLOCKS * CRYPT_SHAKE128_BLOCKSIZE;
uint32_t ctr0 = MLKEMRejUniform(poly0, buf0);
uint32_t ctr1 = MLKEMRejUniform(poly1, buf1);
while (ctr0 < MLKEM_N || ctr1 < MLKEM_N) {
uint32_t off = buflen % 3;
for (uint32_t m = 0; m < off; m++) {
buf0[m] = buf0[buflen - off + m];
buf1[m] = buf1[buflen - off + m];
}
Keccakx2Squeeze(buf0 + off, buf1 + off, 1, CRYPT_SHAKE128_BLOCKSIZE, state);
buflen = off + CRYPT_SHAKE128_BLOCKSIZE;
ctr0 += Parse(poly0 + ctr0, buf0, buflen, MLKEM_N - ctr0);
ctr1 += Parse(poly1 + ctr1, buf1, buflen, MLKEM_N - ctr1);
}
}
static void GenMatrixSingle(const uint8_t *seed, int16_t *poly, uint8_t x, uint8_t y)
{
uint8_t buf[GEN_MATRIX_NBLOCKS * CRYPT_SHAKE128_BLOCKSIZE + 2];
ShakeState state;
KyberShakeAbsorb(&state, seed, x, y);
KeccakSqueeze(buf, GEN_MATRIX_NBLOCKS, state.s, CRYPT_SHAKE128_BLOCKSIZE);
uint32_t buflen = GEN_MATRIX_NBLOCKS * CRYPT_SHAKE128_BLOCKSIZE;
uint32_t ctr = MLKEMRejUniform(poly, buf);
while (ctr < MLKEM_N) {
uint32_t off = buflen % 3;
for (uint32_t m = 0; m < off; m++) {
buf[m] = buf[buflen - off + m];
}
KeccakSqueeze(buf + off, 1, state.s, CRYPT_SHAKE128_BLOCKSIZE);
buflen = off + CRYPT_SHAKE128_BLOCKSIZE;
ctr += Parse(poly + ctr, buf, buflen, MLKEM_N - ctr);
}
}
int32_t GenMatrix(const CRYPT_ML_KEM_Ctx *ctx, const uint8_t *seed, int16_t *polyMatrix[MLKEM_K_MAX][MLKEM_K_MAX],
bool isEnc)
{
uint8_t k = ctx->info->k;
if (k == 2 || k == 4) {
for (uint32_t i = 0; i < k; i++) {
for (uint32_t j = 0; j < k; j += 2) {
uint8_t x1 = isEnc ? i : j;
uint8_t x2 = isEnc ? i : j + 1;
uint8_t y1 = isEnc ? j : i;
uint8_t y2 = isEnc ? j + 1 : i;
GenMatrixX2Pair(seed, polyMatrix[i][j], polyMatrix[i][j + 1], x1, x2, y1, y2);
}
}
} else if (k == 3) {
for (uint32_t idx = 0; idx < 8; idx += 2) {
uint32_t i1 = idx / 3, j1 = idx % 3;
uint32_t i2 = (idx + 1) / 3, j2 = (idx + 1) % 3;
uint8_t x1 = isEnc ? i1 : j1;
uint8_t x2 = isEnc ? i2 : j2;
uint8_t y1 = isEnc ? j1 : i1;
uint8_t y2 = isEnc ? j2 : i2;
GenMatrixX2Pair(seed, polyMatrix[i1][j1], polyMatrix[i2][j2], x1, x2, y1, y2);
}
GenMatrixSingle(seed, polyMatrix[2][2], 2, 2);
}
return CRYPT_SUCCESS;
}
int32_t MLKEM_PKEGen(CRYPT_ML_KEM_Ctx *ctx, uint8_t *digest, uint8_t *pk, uint8_t *dk)
{
int32_t ret = CRYPT_SUCCESS;
uint8_t k = ctx->info->k;
uint8_t *p = digest;
uint8_t *q = digest + CRYPT_SHA3_512_DIGESTSIZE / 2;
GOTO_ERR_IF(GenMatrix(ctx, p, ctx->keyData.matrix, false), ret);
GOTO_ERR_IF(SampleEta1(ctx, q, ctx->keyData.vectorS, ctx->keyData.vectorT), ret);
int16_t s_asym[MLKEM_K_MAX][MLKEM_N >> 1];
int16_t e[MLKEM_K_MAX][MLKEM_N];
int16_t *ae[MLKEM_K_MAX] = {e[0], e[1], e[2], e[3]};
for (uint32_t i = 0; i < k; i++) {
MLKEMPointMulExtended(s_asym[i], ctx->keyData.vectorS[i], MLKEM_BASEMUL_TWIDDLE_TABLE);
}
for (uint32_t i = 0; i < k; i++) {
MLKEMAsymMulMont(k, ae[i], ctx->keyData.matrix[i][0], &(ctx->keyData.vectorS[0][0]), &(s_asym[0][0]));
}
MLKEMPolyAddReduce(k, ctx->keyData.vectorT, ae);
MLKEMPolyVecReduce(k, ctx->keyData.vectorS);
MLKEMPolyVecReduce(k, ctx->keyData.vectorT);
PolyVectoBytes(k, pk, ctx->keyData.vectorT);
PolyVectoBytes(k, dk, ctx->keyData.vectorS);
memcpy(pk + k * MLKEM_CIPHER_LEN, p, MLKEM_SEED_LEN);
ERR:
return ret;
}
int32_t MLKEM_PKEEnc(uint32_t k, MLKEM_MatrixSt *mat, uint8_t du, uint8_t dv, uint8_t *ct, int16_t *y[], int16_t *e1[],
int16_t *u[], int16_t *e2, int16_t *mu, int16_t *c2)
{
(void)dv;
(void)u;
int16_t *transMatrix[MLKEM_K_MAX][MLKEM_K_MAX] = {0};
for (uint32_t i = 0; i < k; i++) {
for (uint32_t j = 0; j < k; j++) {
transMatrix[j][i] = mat->matrix[i][j];
}
}
int16_t s_asym[MLKEM_K_MAX][MLKEM_N >> 1];
int16_t t[MLKEM_K_MAX][MLKEM_N];
int16_t *at[MLKEM_K_MAX] = {t[0], t[1], t[2], t[3]};
for (uint32_t i = 0; i < k; i++) {
MLKEMPointMulExtended(&(s_asym[i][0]), &(y[i][0]), MLKEM_BASEMUL_TWIDDLE_TABLE);
}
for (uint32_t i = 0; i < k; i++) {
MLKEMAsymMul(k, at[i], transMatrix[i], &(y[0][0]), &(s_asym[0][0]));
}
MLKEMAsymMul(k, c2, mat->vectorT, &(y[0][0]), &(s_asym[0][0]));
MLKEMPolyINTTtoMont(k, at);
MLKEM_ComputeINTTAsm(c2);
MLKEMPolyAddReduce(k, at, e1);
MLKEMAdd2Reduce(c2, e2, mu);
PackCipherText(k, MLKEM_ENCODE_BLOCKSIZE * k * du, ct, at, c2);
return CRYPT_SUCCESS;
}
int32_t MLKEM_PKEDec(uint32_t k, MLKEM_MatrixSt *mat, int16_t *m, int16_t *c1[], int16_t *c2, uint8_t *result)
{
int16_t b_asym[MLKEM_K_MAX][MLKEM_N >> 1];
for (uint32_t i = 0; i < k; i++) {
MLKEM_ComputeNTTAsm(c1[i]);
MLKEMPointMulExtended(&(b_asym[i][0]), &(c1[i][0]), MLKEM_BASEMUL_TWIDDLE_TABLE);
}
MLKEMAsymMul(k, m, mat->vectorS, &(c1[0][0]), &(b_asym[0][0]));
MLKEM_ComputeINTTAsm(m);
MLKEMSubReduce(c2, m);
MLKEMPolyCompress1(result, c2);
return CRYPT_SUCCESS;
}
void MLKEM_SamplePolyCBD(int16_t *polyF, uint8_t *buf, uint8_t eta)
{
MLKEMPolyCBDEta(eta, polyF, buf);
}
#endif