/*
 * This file is part of the openHiTLS project.
*
* openHiTLS is licensed under the Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "hitls_build.h"
#ifdef HITLS_CRYPTO_HSS
#include <string.h>
#include "bsl_sal.h"
#include "bsl_bytes.h"
#include "bsl_err_internal.h"
#include "crypt_errno.h"
#include "crypt_local_types.h"
#include "crypt_params_key.h"
#include "hss_local.h"
CRYPT_HSS_Ctx *CRYPT_HSS_NewCtx(void)
{
CRYPT_HSS_Ctx *ctx = (CRYPT_HSS_Ctx *)BSL_SAL_Calloc(1, sizeof(CRYPT_HSS_Ctx));
if (ctx == NULL) {
return NULL;
}
return ctx;
}
/**
 * @ingroup hss
 * @brief Allocate an HSS context bound to a library context.
 * @param libCtx [IN] Library context stored in the new context (not owned, may be NULL)
 * @return New context, or NULL if the underlying allocation fails
 */
CRYPT_HSS_Ctx *CRYPT_HSS_NewCtxEx(void *libCtx)
{
    CRYPT_HSS_Ctx *created = CRYPT_HSS_NewCtx();
    if (created != NULL) {
        created->libCtx = libCtx;
    }
    return created;
}
/**
 * @ingroup hss
 * @brief Release an HSS context together with every buffer it owns.
 *
 * All owned buffers (private key, public key, per-level cached trees) and the context
 * itself are released with BSL_SAL_ClearFree, so their contents are scrubbed on free.
 *
 * @param ctx [IN] Context to free; NULL is accepted and ignored
 */
void CRYPT_HSS_FreeCtx(CRYPT_HSS_Ctx *ctx)
{
    if (ctx == NULL) {
        return;
    }
    BSL_SAL_ClearFree(ctx->privateKey, HSS_PRVKEY_LEN);
    /* The public key is scrubbed too, even though its bytes are public-by-spec
     * (algorithm IDs, identifier I, root hash): this matches how the private key is
     * released and keeps no structured key material lingering in the freed heap chunk. */
    BSL_SAL_ClearFree(ctx->publicKey, HSS_PUBKEY_LEN);
    for (uint32_t i = 0; i < HSS_LEVELS_ARRAY_SIZE; i++) {
        BSL_SAL_ClearFree(ctx->cachedTrees[i], ctx->cachedTreeSizes[i]);
    }
    BSL_SAL_ClearFree(ctx, sizeof(CRYPT_HSS_Ctx));
}
/**
 * @ingroup hss
 * @brief Duplicate an HSS context's parameters and public key.
 *
 * Copied: library context, parameter set, public key (if present).
 * Not copied: private key, tree caches; the signature index starts at 0.
 * NOTE(review): leaving the private key out is presumably deliberate -- two contexts signing
 * from the same stateful HSS counter could reuse one-time keys. Confirm against the design.
 *
 * @param srcCtx [IN] Context to duplicate
 * @return New context, or NULL if srcCtx is NULL or an allocation fails
 */
CRYPT_HSS_Ctx *CRYPT_HSS_DupCtx(CRYPT_HSS_Ctx *srcCtx)
{
    if (srcCtx == NULL) {
        return NULL;
    }
    /* Create through NewCtxEx so the duplicate keeps the source's library context. */
    CRYPT_HSS_Ctx *newCtx = CRYPT_HSS_NewCtxEx(srcCtx->libCtx);
    if (newCtx == NULL) {
        return NULL;
    }
    newCtx->para = srcCtx->para;
    if (srcCtx->publicKey != NULL) {
        newCtx->publicKey = (uint8_t *)BSL_SAL_Calloc(1, HSS_PUBKEY_LEN);
        if (newCtx->publicKey == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            CRYPT_HSS_FreeCtx(newCtx);
            return NULL;
        }
        memcpy(newCtx->publicKey, srcCtx->publicKey, HSS_PUBKEY_LEN);
    }
    newCtx->signatureIndex = 0;
    return newCtx;
}
/**
 * @brief Compare two optional fixed-length key buffers.
 *
 * Two NULL buffers compare equal, and a NULL buffer never equals a non-NULL one. Otherwise
 * the contents are compared with ConstTimeMemcmp, which this file treats as returning
 * non-zero when the buffers match.
 */
static bool HssOptionalKeyEqual(const uint8_t *lhs, const uint8_t *rhs, uint32_t len)
{
    if (lhs == NULL || rhs == NULL) {
        return lhs == rhs;
    }
    return ConstTimeMemcmp(lhs, rhs, len) != 0;
}

/**
 * @ingroup hss
 * @brief Compare two HSS contexts.
 *
 * Both contexts match only if all of the following hold:
 * - they have the same level count;
 * - they have the same per-level LMS and LM-OTS types;
 * - public-key presence and contents match;
 * - private-key presence and contents match.
 *
 * NOTE(review): CRYPT_HSS_SetPubKey only fills in the level-0 types, and a public-only context
 * has no private key. Comparing a public-only context with a full key pair therefore reports
 * a mismatch even when the public keys agree. Confirm this is intended.
 *
 * @param ctx1 [IN] First context
 * @param ctx2 [IN] Second context
 * @return CRYPT_SUCCESS if equal, CRYPT_HSS_CMP_FALSE otherwise (including NULL input)
 */
int32_t CRYPT_HSS_Cmp(CRYPT_HSS_Ctx *ctx1, CRYPT_HSS_Ctx *ctx2)
{
    bool same = (ctx1 != NULL) && (ctx2 != NULL) && (ctx1->para.levels == ctx2->para.levels);
    for (uint32_t lvl = 0; same && (lvl < ctx1->para.levels); lvl++) {
        same = (ctx1->para.lmsType[lvl] == ctx2->para.lmsType[lvl]) &&
            (ctx1->para.otsType[lvl] == ctx2->para.otsType[lvl]);
    }
    same = same && HssOptionalKeyEqual(ctx1->publicKey, ctx2->publicKey, HSS_PUBKEY_LEN) &&
        HssOptionalKeyEqual(ctx1->privateKey, ctx2->privateKey, HSS_PRVKEY_LEN);
    if (!same) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_CMP_FALSE);
        return CRYPT_HSS_CMP_FALSE;
    }
    return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Set number of hierarchy levels
* @param ctx [IN/OUT] HSS context
* @param val [IN] Levels value (1-8)
* @param valLen [IN] Value length (must be sizeof(uint32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlSetLevels(CRYPT_HSS_Ctx *ctx, void *val, uint32_t valLen)
{
if (valLen < sizeof(uint32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
uint32_t levels = *(uint32_t *)val;
if (levels < HSS_MIN_LEVELS || levels > HSS_MAX_COMPRESSED_LEVELS) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_LEVEL);
return CRYPT_HSS_INVALID_LEVEL;
}
ctx->para.levels = levels;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Set LMS type for a specific level
* @param ctx [IN/OUT] HSS context
* @param val [IN] Array: [level_index, lms_type]
* @param valLen [IN] Value length (must be 2 * sizeof(uint32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlSetLmsType(CRYPT_HSS_Ctx *ctx, void *val, uint32_t valLen)
{
if (valLen < 2 * sizeof(uint32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
uint32_t *params = (uint32_t *)val;
uint32_t levelIdx = params[0];
uint32_t lmsType = params[1];
if (levelIdx >= ctx->para.levels || levelIdx >= HSS_MAX_LEVELS) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_LEVEL_OUT_OF_RANGE);
return CRYPT_HSS_LEVEL_OUT_OF_RANGE;
}
if (lmsType < LMS_SHA256_M32_H5 || lmsType > LMS_SHA256_M32_H25) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
ctx->para.lmsType[levelIdx] = lmsType;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Set OTS type for a specific level
* @param ctx [IN/OUT] HSS context
* @param val [IN] Array: [level_index, ots_type]
* @param valLen [IN] Value length (must be 2 * sizeof(uint32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlSetOtsType(CRYPT_HSS_Ctx *ctx, void *val, uint32_t valLen)
{
if (valLen < 2 * sizeof(uint32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
uint32_t *params = (uint32_t *)val;
uint32_t levelIdx = params[0];
uint32_t otsType = params[1];
if (levelIdx >= ctx->para.levels || levelIdx >= HSS_MAX_LEVELS) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_LEVEL_OUT_OF_RANGE);
return CRYPT_HSS_LEVEL_OUT_OF_RANGE;
}
if (otsType < LMOTS_SHA256_N32_W1 || otsType > LMOTS_SHA256_N32_W8) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
ctx->para.otsType[levelIdx] = otsType;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Get public key length
* @param val [OUT] Public key length (always 60)
* @param valLen [IN] Value buffer length (must be sizeof(uint32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlGetPubKeyLen(void *val, uint32_t valLen)
{
if (valLen < sizeof(uint32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
*(uint32_t *)val = HSS_PUBKEY_LEN;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Get private key length
* @param val [OUT] Private key length (always 48)
* @param valLen [IN] Value buffer length (must be sizeof(uint32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlGetPrvKeyLen(void *val, uint32_t valLen)
{
if (valLen < sizeof(uint32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
*(uint32_t *)val = HSS_PRVKEY_LEN;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Get signature length
* @param ctx [IN] HSS context
* @param val [OUT] Signature length
* @param valLen [IN] Value buffer length (must be sizeof(uint32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlGetSigLen(CRYPT_HSS_Ctx *ctx, void *val, uint32_t valLen)
{
if (valLen < sizeof(uint32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
if (ctx->para.pubKeyLen == 0) {
int32_t ret = HssParaInit(&ctx->para, ctx->para.levels, ctx->para.lmsType, ctx->para.otsType);
if (ret != CRYPT_SUCCESS) {
BSL_ERR_PUSH_ERROR(ret);
return ret;
}
}
size_t sigLen = HssGetSignatureLen(&ctx->para);
if (sigLen == 0) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
*(uint32_t *)val = (uint32_t)sigLen;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Get remaining signature capacity
* @param ctx [IN] HSS context
* @param val [OUT] Remaining signatures
* @param valLen [IN] Value buffer length (must be sizeof(uint64_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlGetRemaining(CRYPT_HSS_Ctx *ctx, void *val, uint32_t valLen)
{
if (valLen < sizeof(uint64_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
if (ctx->para.levels == 0 || ctx->privateKey == NULL) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_NO_KEY);
return CRYPT_HSS_NO_KEY;
}
if (ctx->para.pubKeyLen == 0) {
int32_t ret = HssParaInit(&ctx->para, ctx->para.levels, ctx->para.lmsType, ctx->para.otsType);
if (ret != CRYPT_SUCCESS) {
BSL_ERR_PUSH_ERROR(ret);
return ret;
}
}
uint64_t maxSigs = HssGetMaxSignatures(&ctx->para);
uint64_t counter = LmsGetBigendian(ctx->privateKey + HSS_PRVKEY_COUNTER_OFFSET, HSS_PRVKEY_COUNTER_LEN);
uint64_t remaining = (maxSigs > 0 && counter < maxSigs) ? (maxSigs - counter) : 0;
*(uint64_t *)val = remaining;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Get number of hierarchy levels
* @param ctx [IN] HSS context
* @param val [OUT] Number of levels
* @param valLen [IN] Value buffer length (must be sizeof(uint32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlGetLevels(CRYPT_HSS_Ctx *ctx, void *val, uint32_t valLen)
{
if (valLen < sizeof(uint32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
*(uint32_t *)val = ctx->para.levels;
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Set HSS parameters by algorithm ID
* @param ctx [IN/OUT] HSS context
* @param val [IN] Algorithm ID value
* @param valLen [IN] Value length (must be sizeof(int32_t))
* @return CRYPT_SUCCESS on success, error code on failure
*/
static int32_t HssCtrlSetParaById(CRYPT_HSS_Ctx *ctx, void *val, uint32_t valLen)
{
if (ctx->para.levels != 0) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_CTRL_INIT_REPEATED);
return CRYPT_HSS_CTRL_INIT_REPEATED;
}
if (valLen != sizeof(int32_t)) {
BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
return CRYPT_INVALID_ARG;
}
int32_t algId = *(int32_t *)val;
uint32_t levels;
uint32_t lmsTypes[HSS_LEVELS_ARRAY_SIZE] = {0};
uint32_t otsTypes[HSS_LEVELS_ARRAY_SIZE] = {0};
switch (algId) {
case CRYPT_HSS_SHA256_L2_H10_H10:
levels = 2;
lmsTypes[0] = LMS_SHA256_M32_H10;
lmsTypes[1] = LMS_SHA256_M32_H10;
otsTypes[0] = LMOTS_SHA256_N32_W4;
otsTypes[1] = LMOTS_SHA256_N32_W4;
break;
case CRYPT_HSS_SHA256_L2_H15_H15:
levels = 2;
lmsTypes[0] = LMS_SHA256_M32_H15;
lmsTypes[1] = LMS_SHA256_M32_H15;
otsTypes[0] = LMOTS_SHA256_N32_W4;
otsTypes[1] = LMOTS_SHA256_N32_W4;
break;
case CRYPT_HSS_SHA256_L2_H20_H20:
levels = 2;
lmsTypes[0] = LMS_SHA256_M32_H20;
lmsTypes[1] = LMS_SHA256_M32_H20;
otsTypes[0] = LMOTS_SHA256_N32_W4;
otsTypes[1] = LMOTS_SHA256_N32_W4;
break;
case CRYPT_HSS_SHA256_L3_H10_H10_H10:
levels = 3;
lmsTypes[0] = LMS_SHA256_M32_H10;
lmsTypes[1] = LMS_SHA256_M32_H10;
lmsTypes[2] = LMS_SHA256_M32_H10;
otsTypes[0] = LMOTS_SHA256_N32_W4;
otsTypes[1] = LMOTS_SHA256_N32_W4;
otsTypes[2] = LMOTS_SHA256_N32_W4;
break;
default:
BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
return CRYPT_HSS_INVALID_PARAM;
}
return HssParaInit(&ctx->para, levels, lmsTypes, otsTypes);
}
/**
 * @ingroup hss
 * @brief Control entry point for HSS contexts
 *
 * Dispatches cmd to the matching HssCtrl* handler. val/valLen are passed through unchanged
 * and interpreted by that handler.
 *
 * @param ctx    [IN/OUT] HSS context
 * @param cmd    [IN] Control command
 * @param val    [IN/OUT] Command-specific input or output buffer
 * @param valLen [IN] Length of val in bytes
 * @return CRYPT_SUCCESS on success;
 *         CRYPT_NULL_INPUT if ctx or val is NULL;
 *         CRYPT_HSS_INVALID_CMD for an unknown command;
 *         otherwise the handler's error code
 */
int32_t CRYPT_HSS_Ctrl(CRYPT_HSS_Ctx *ctx, int32_t cmd, void *val, uint32_t valLen)
{
    if (ctx == NULL || val == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    int32_t ret;
    switch (cmd) {
        /* Parameter configuration */
        case CRYPT_CTRL_SET_PARA_BY_ID:
            ret = HssCtrlSetParaById(ctx, val, valLen);
            break;
        case CRYPT_CTRL_HSS_SET_LEVELS:
            ret = HssCtrlSetLevels(ctx, val, valLen);
            break;
        case CRYPT_CTRL_HSS_SET_LMS_TYPE:
            ret = HssCtrlSetLmsType(ctx, val, valLen);
            break;
        case CRYPT_CTRL_HSS_SET_OTS_TYPE:
            ret = HssCtrlSetOtsType(ctx, val, valLen);
            break;
        /* Queries */
        case CRYPT_CTRL_HSS_GET_PUBKEY_LEN:
            ret = HssCtrlGetPubKeyLen(val, valLen);
            break;
        case CRYPT_CTRL_HSS_GET_PRVKEY_LEN:
            ret = HssCtrlGetPrvKeyLen(val, valLen);
            break;
        case CRYPT_CTRL_HSS_GET_SIG_LEN:
            ret = HssCtrlGetSigLen(ctx, val, valLen);
            break;
        case CRYPT_CTRL_HSS_GET_REMAINING:
            ret = HssCtrlGetRemaining(ctx, val, valLen);
            break;
        case CRYPT_CTRL_HSS_GET_LEVELS:
            ret = HssCtrlGetLevels(ctx, val, valLen);
            break;
        default:
            BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_CMD);
            ret = CRYPT_HSS_INVALID_CMD;
            break;
    }
    return ret;
}
/**
 * @brief Drop every per-level cached tree.
 *
 * Each cached buffer is released with BSL_SAL_ClearFree, and the slot is reset to
 * "no cache" (NULL buffer, size 0, invalid, index 0).
 *
 * @param ctx [IN/OUT] HSS context whose caches are cleared
 */
static void HssInvalidateAllTreeCaches(CRYPT_HSS_Ctx *ctx)
{
    for (uint32_t lvl = 0; lvl < HSS_LEVELS_ARRAY_SIZE; ++lvl) {
        ctx->treeCacheValid[lvl] = false;
        ctx->cachedTreeIndex[lvl] = 0;
        BSL_SAL_ClearFree(ctx->cachedTrees[lvl], ctx->cachedTreeSizes[lvl]);
        ctx->cachedTrees[lvl] = NULL;
        ctx->cachedTreeSizes[lvl] = 0;
    }
}
/**
 * @ingroup hss
 * @brief Load an encoded HSS private key into the context
 *
 * The parameter set embedded in the key is decoded and validated before the context is
 * touched. On success, the following happen in order:
 * - the key bytes are copied into the context;
 * - the parameter set is replaced by the decoded one;
 * - all cached trees are invalidated;
 * - the signature index is set from the key's big-endian counter field.
 *
 * @param ctx   [IN/OUT] HSS context
 * @param param [IN] Parameter list containing CRYPT_PARAM_HSS_PRVKEY (HSS_PRVKEY_LEN bytes)
 * @return CRYPT_SUCCESS on success;
 *         CRYPT_NULL_INPUT if ctx or param is NULL;
 *         CRYPT_HSS_NO_KEY if the key parameter is missing;
 *         CRYPT_HSS_INVALID_KEY_LEN if the key has the wrong length;
 *         CRYPT_MEM_ALLOC_FAIL on allocation failure;
 *         otherwise the parameter-decoding error
 */
int32_t CRYPT_HSS_SetPrvKey(CRYPT_HSS_Ctx *ctx, BSL_Param *param)
{
    if (ctx == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    const BSL_Param *keyParam = BSL_PARAM_FindConstParam(param, CRYPT_PARAM_HSS_PRVKEY);
    if (keyParam == NULL || keyParam->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_NO_KEY);
        return CRYPT_HSS_NO_KEY;
    }
    if (keyParam->valueLen != HSS_PRVKEY_LEN) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_KEY_LEN);
        return CRYPT_HSS_INVALID_KEY_LEN;
    }
    const uint8_t *encoded = (const uint8_t *)keyParam->value;

    /* Decode the embedded parameter set first; nothing in ctx changes if this fails. */
    uint8_t paramBlob[HSS_COMPRESSED_PARAMS_LEN];
    memcpy(paramBlob, encoded + HSS_PRVKEY_PARAMS_OFFSET, HSS_PRVKEY_PARAMS_LEN);
    HSS_Para decoded;
    memset(&decoded, 0, sizeof(decoded));
    int32_t ret = HssDecompressParamSet(&decoded, paramBlob);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    if (ctx->privateKey == NULL) {
        ctx->privateKey = (uint8_t *)BSL_SAL_Calloc(1, HSS_PRVKEY_LEN);
        if (ctx->privateKey == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
    }
    memcpy(ctx->privateKey, encoded, HSS_PRVKEY_LEN);
    ctx->para = decoded;
    HssInvalidateAllTreeCaches(ctx);
    ctx->signatureIndex = LmsGetBigendian(ctx->privateKey + HSS_PRVKEY_COUNTER_OFFSET, HSS_PRVKEY_COUNTER_LEN);
    return CRYPT_SUCCESS;
}
/**
 * @brief Validate the level count and top-level types decoded from a public key
 *
 * Checks, in order, that:
 * - the level count lies in [HSS_MIN_LEVELS, HSS_MAX_COMPRESSED_LEVELS];
 * - the LMS type is known to LmsLookupParamSet;
 * - the LM-OTS type is known to LmOtsLookupParamSet.
 *
 * @param levels  [IN] Level count
 * @param lmsType [IN] Top-level LMS type code
 * @param otsType [IN] Top-level LM-OTS type code
 * @return CRYPT_SUCCESS if all checks pass, CRYPT_HSS_INVALID_PARAM (pushed) otherwise
 */
static int32_t HssValidatePubKeyTypes(uint32_t levels, uint32_t lmsType, uint32_t otsType)
{
    LmOtsParams otsScratch;
    bool levelsOk = (levels >= HSS_MIN_LEVELS) && (levels <= HSS_MAX_COMPRESSED_LEVELS);
    /* Short-circuit evaluation keeps the original check order: the lookups only run when
     * the preceding checks passed. */
    if (!levelsOk ||
        LmsLookupParamSet(lmsType, NULL, NULL, NULL) != CRYPT_SUCCESS ||
        LmOtsLookupParamSet(otsType, &otsScratch) != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_PARAM);
        return CRYPT_HSS_INVALID_PARAM;
    }
    return CRYPT_SUCCESS;
}
/**
 * @ingroup hss
 * @brief Load an encoded HSS public key into the context
 *
 * From the encoding, this function reads the level count and the top-level LMS / LM-OTS
 * types and validates them. It then stores a copy of the key and records those three values
 * in ctx->para (level 0 only for the types).
 *
 * @param ctx   [IN/OUT] HSS context
 * @param param [IN] Parameter list containing CRYPT_PARAM_HSS_PUBKEY (HSS_PUBKEY_LEN bytes)
 * @return CRYPT_SUCCESS on success;
 *         CRYPT_NULL_INPUT if ctx or param is NULL;
 *         CRYPT_HSS_NO_KEY if the key parameter is missing;
 *         CRYPT_HSS_INVALID_KEY_LEN if the key has the wrong length;
 *         CRYPT_HSS_INVALID_PARAM if the encoded types are unsupported;
 *         CRYPT_MEM_ALLOC_FAIL on allocation failure
 */
int32_t CRYPT_HSS_SetPubKey(CRYPT_HSS_Ctx *ctx, BSL_Param *param)
{
    if (ctx == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    const BSL_Param *keyParam = BSL_PARAM_FindConstParam(param, CRYPT_PARAM_HSS_PUBKEY);
    if (keyParam == NULL || keyParam->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_NO_KEY);
        return CRYPT_HSS_NO_KEY;
    }
    if (keyParam->valueLen != HSS_PUBKEY_LEN) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_KEY_LEN);
        return CRYPT_HSS_INVALID_KEY_LEN;
    }
    const uint8_t *encoded = (const uint8_t *)keyParam->value;
    const uint32_t topLevels = (uint32_t)LmsGetBigendian(encoded + HSS_PUBKEY_LEVELS_OFFSET, LMS_TYPE_LEN);
    const uint32_t topLms = (uint32_t)LmsGetBigendian(encoded + HSS_PUBKEY_LMS_TYPE_OFFSET, LMS_TYPE_LEN);
    const uint32_t topOts = (uint32_t)LmsGetBigendian(encoded + HSS_PUBKEY_OTS_TYPE_OFFSET, LMS_TYPE_LEN);
    int32_t ret = HssValidatePubKeyTypes(topLevels, topLms, topOts);
    if (ret != CRYPT_SUCCESS) {
        return ret; /* HssValidatePubKeyTypes has already pushed the error */
    }
    if (ctx->publicKey == NULL) {
        ctx->publicKey = (uint8_t *)BSL_SAL_Calloc(1, HSS_PUBKEY_LEN);
        if (ctx->publicKey == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
    }
    memcpy(ctx->publicKey, encoded, HSS_PUBKEY_LEN);
    ctx->para.levels = topLevels;
    ctx->para.lmsType[0] = topLms;
    ctx->para.otsType[0] = topOts;
    return CRYPT_SUCCESS;
}
/**
 * @ingroup hss
 * @brief Export the encoded private key into the CRYPT_PARAM_HSS_PRVKEY entry of param
 *
 * On success, HSS_PRVKEY_LEN bytes are written to the entry's buffer and its useLen is set
 * to HSS_PRVKEY_LEN.
 *
 * @param ctx   [IN] HSS context holding a private key
 * @param param [IN/OUT] Parameter list with a CRYPT_PARAM_HSS_PRVKEY output buffer
 * @return CRYPT_SUCCESS on success;
 *         CRYPT_NULL_INPUT if ctx/param is NULL or the output entry/buffer is missing;
 *         CRYPT_HSS_NO_KEY if no private key is loaded;
 *         CRYPT_HSS_INVALID_KEY_LEN if the output buffer is too small
 */
int32_t CRYPT_HSS_GetPrvKey(CRYPT_HSS_Ctx *ctx, BSL_Param *param)
{
    if (ctx == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (ctx->privateKey == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_NO_KEY);
        return CRYPT_HSS_NO_KEY;
    }
    BSL_Param *out = BSL_PARAM_FindParam(param, CRYPT_PARAM_HSS_PRVKEY);
    bool haveBuffer = (out != NULL) && (out->value != NULL);
    if (!haveBuffer) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (out->valueLen >= HSS_PRVKEY_LEN) {
        memcpy(out->value, ctx->privateKey, HSS_PRVKEY_LEN);
        out->useLen = HSS_PRVKEY_LEN;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_KEY_LEN);
    return CRYPT_HSS_INVALID_KEY_LEN;
}
/**
 * @ingroup hss
 * @brief Export the encoded public key into the CRYPT_PARAM_HSS_PUBKEY entry of param
 *
 * On success, HSS_PUBKEY_LEN bytes are written to the entry's buffer and its useLen is set
 * to HSS_PUBKEY_LEN.
 *
 * @param ctx   [IN] HSS context holding a public key
 * @param param [IN/OUT] Parameter list with a CRYPT_PARAM_HSS_PUBKEY output buffer
 * @return CRYPT_SUCCESS on success;
 *         CRYPT_NULL_INPUT if ctx/param is NULL or the output entry/buffer is missing;
 *         CRYPT_HSS_NO_KEY if no public key is loaded;
 *         CRYPT_HSS_INVALID_KEY_LEN if the output buffer is too small
 */
int32_t CRYPT_HSS_GetPubKey(CRYPT_HSS_Ctx *ctx, BSL_Param *param)
{
    if (ctx == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (ctx->publicKey == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_HSS_NO_KEY);
        return CRYPT_HSS_NO_KEY;
    }
    BSL_Param *out = BSL_PARAM_FindParam(param, CRYPT_PARAM_HSS_PUBKEY);
    bool haveBuffer = (out != NULL) && (out->value != NULL);
    if (!haveBuffer) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (out->valueLen >= HSS_PUBKEY_LEN) {
        memcpy(out->value, ctx->publicKey, HSS_PUBKEY_LEN);
        out->useLen = HSS_PUBKEY_LEN;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(CRYPT_HSS_INVALID_KEY_LEN);
    return CRYPT_HSS_INVALID_KEY_LEN;
}
#ifdef HITLS_CRYPTO_HSS_CHECK
* @ingroup hss
* @brief Verify basic HSS parameters match between public and private keys
* @param pubKey [IN] Public key context
* @param prvKey [IN] Private key context
* @return CRYPT_SUCCESS if parameters match, error code otherwise
*/
static int32_t HSSCheckBasicParams(const CRYPT_HSS_Ctx *pubKey, const CRYPT_HSS_Ctx *prvKey)
{
uint32_t pubLevels = (uint32_t)LmsGetBigendian(pubKey->publicKey + HSS_PUBKEY_LEVELS_OFFSET, LMS_TYPE_LEN);
if (pubLevels != prvKey->para.levels) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_PAIRWISE_CHECK_FAIL);
return CRYPT_HSS_PAIRWISE_CHECK_FAIL;
}
uint32_t pubLmsType = (uint32_t)LmsGetBigendian(pubKey->publicKey + HSS_PUBKEY_LMS_TYPE_OFFSET, LMS_TYPE_LEN);
uint32_t pubOtsType = (uint32_t)LmsGetBigendian(pubKey->publicKey + HSS_PUBKEY_OTS_TYPE_OFFSET, LMS_TYPE_LEN);
if (pubLmsType != prvKey->para.lmsType[0] || pubOtsType != prvKey->para.otsType[0]) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_PAIRWISE_CHECK_FAIL);
return CRYPT_HSS_PAIRWISE_CHECK_FAIL;
}
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Verify root hash matches between public and private keys
* @param pubKey [IN] Public key context
* @param prvKey [IN] Private key context
* @param rootI [IN] Root tree identifier (16 bytes)
* @param rootSeed [IN] Root tree seed (32 bytes)
* @return CRYPT_SUCCESS if root hash matches, error code otherwise
*/
static int32_t HSSVerifyRootHash(const CRYPT_HSS_Ctx *pubKey, const CRYPT_HSS_Ctx *prvKey, const uint8_t *rootI,
const uint8_t *rootSeed)
{
if (ConstTimeMemcmp(rootI, pubKey->publicKey + HSS_PUBKEY_I_OFFSET, LMS_I_LEN) == 0) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_PAIRWISE_CHECK_FAIL);
return CRYPT_HSS_PAIRWISE_CHECK_FAIL;
}
LMS_Para lmsPara;
int32_t ret = LmsParaInit(&lmsPara, prvKey->para.lmsType[0], prvKey->para.otsType[0]);
if (ret != CRYPT_SUCCESS) {
BSL_ERR_PUSH_ERROR(ret);
return ret;
}
uint8_t computedRoot[LMS_SHA256_N];
ret = LmsComputeRoot(computedRoot, &lmsPara, rootI, rootSeed);
if (ret != CRYPT_SUCCESS) {
BSL_ERR_PUSH_ERROR(ret);
return ret;
}
int32_t cmpRet = ConstTimeMemcmp(computedRoot, pubKey->publicKey + HSS_PUBKEY_ROOT_OFFSET, LMS_SHA256_N);
BSL_SAL_CleanseData(computedRoot, sizeof(computedRoot));
if (cmpRet == 0) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_PAIRWISE_CHECK_FAIL);
return CRYPT_HSS_PAIRWISE_CHECK_FAIL;
}
return CRYPT_SUCCESS;
}
* @ingroup hss
* @brief Verify HSS key pair consistency
* @param pubKey [IN] Public key context
* @param prvKey [IN] Private key context
* @return CRYPT_SUCCESS if keys match, error code otherwise
*/
static int32_t HSSKeyPairCheck(const CRYPT_HSS_Ctx *pubKey, const CRYPT_HSS_Ctx *prvKey)
{
if (pubKey == NULL || prvKey == NULL) {
BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
return CRYPT_NULL_INPUT;
}
if (pubKey->para.levels == 0 || prvKey->para.levels == 0) {
BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
return CRYPT_NULL_INPUT;
}
if (pubKey->publicKey == NULL || prvKey->privateKey == NULL) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_NO_KEY);
return CRYPT_HSS_NO_KEY;
}
int32_t ret = HSSCheckBasicParams(pubKey, prvKey);
if (ret != CRYPT_SUCCESS) {
BSL_ERR_PUSH_ERROR(ret);
return ret;
}
uint8_t masterSeed[LMS_SEED_LEN];
memcpy(masterSeed, prvKey->privateKey + HSS_PRVKEY_SEED_OFFSET, HSS_PRVKEY_SEED_LEN);
uint8_t rootI[LMS_I_LEN];
uint8_t rootSeed[LMS_SEED_LEN];
ret = HssGenerateRootSeed(rootI, rootSeed, masterSeed);
if (ret != CRYPT_SUCCESS) {
BSL_SAL_CleanseData(masterSeed, sizeof(masterSeed));
BSL_ERR_PUSH_ERROR(ret);
return ret;
}
ret = HSSVerifyRootHash(pubKey, prvKey, rootI, rootSeed);
* root-tree WOTS+ seed. Both must be scrubbed on every exit path of this
* helper. rootI is public-by-spec but cleansed for consistency with
* LmsKeyGen/CRYPT_HSS_Gen. */
BSL_SAL_CleanseData(masterSeed, sizeof(masterSeed));
BSL_SAL_CleanseData(rootSeed, sizeof(rootSeed));
BSL_SAL_CleanseData(rootI, sizeof(rootI));
if (ret != CRYPT_SUCCESS) {
BSL_ERR_PUSH_ERROR(ret);
}
return ret;
}
* @ingroup hss
* @brief Verify private key validity
* @param prvKey [IN] Private key context
* @return CRYPT_SUCCESS if valid, error code otherwise
*/
static int32_t HSSPrvKeyCheck(const CRYPT_HSS_Ctx *prvKey)
{
if (prvKey == NULL) {
BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
return CRYPT_NULL_INPUT;
}
if (prvKey->para.levels == 0) {
BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
return CRYPT_NULL_INPUT;
}
if (prvKey->privateKey == NULL) {
BSL_ERR_PUSH_ERROR(CRYPT_HSS_NO_KEY);
return CRYPT_HSS_NO_KEY;
}
return CRYPT_SUCCESS;
}
/**
 * @ingroup hss
 * @brief Key consistency checks
 * @param checkType [IN] CRYPT_PKEY_CHECK_KEYPAIR or CRYPT_PKEY_CHECK_PRVKEY
 * @param pkey1     [IN] Public-key context (KEYPAIR) or the context to check (PRVKEY)
 * @param pkey2     [IN] Private-key context (KEYPAIR only; ignored for PRVKEY)
 * @return Result of the selected check, or CRYPT_INVALID_ARG for an unknown checkType
 */
int32_t CRYPT_HSS_Check(uint32_t checkType, const CRYPT_HSS_Ctx *pkey1, const CRYPT_HSS_Ctx *pkey2)
{
    if (checkType == CRYPT_PKEY_CHECK_KEYPAIR) {
        return HSSKeyPairCheck(pkey1, pkey2);
    }
    if (checkType == CRYPT_PKEY_CHECK_PRVKEY) {
        return HSSPrvKeyCheck(pkey1);
    }
    BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
    return CRYPT_INVALID_ARG;
}
#endif
#endif