* This file is part of the openHiTLS project.
*
* openHiTLS is licensed under the Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <errno.h>
#include <fcntl.h>
#include "app_errno.h"
#include "app_print.h"
#include "app_utils.h"
#include "app_provider.h"
#include "app_keymgmt.h"
#include "hitls_config.h"
#include "hitls_cert.h"
#include "hitls_pki_cert.h"
#include "hitls_type.h"
#include "cipher_suite.h"
#include "hitls_session.h"
#include "hitls_cert_type.h"
#include "crypt_eal_pkey.h"
#include "crypt_errno.h"
#include "bsl_bytes.h"
#include "bsl_params.h"
#include "bsl_sal.h"
#include "bsl_err.h"
#include "sal_file.h"
#include "crypt_params_key.h"
#include "app_tls_common.h"
#define HEARTBEAT_STR "heartbeat"
static CRYPT_EAL_PkeyCtx *LoadKeyFromFile(APP_CertConfig *certConfig, bool isSignKey);
APP_ProtocolType ParseProtocolType(const char *protocolStr)
{
if (protocolStr == NULL) {
return APP_PROTOCOL_TLS;
}
if (strcmp(protocolStr, "tls") == 0) {
return APP_PROTOCOL_TLS;
} else if (strcmp(protocolStr, "tlcp") == 0) {
return APP_PROTOCOL_TLCP;
} else if (strcmp(protocolStr, "dtlcp") == 0) {
return APP_PROTOCOL_DTLCP;
}
return APP_PROTOCOL_TLS;
}
HITLS_Config *CreateProtocolConfig(APP_ProtocolType protocol, AppProvider *provider)
{
HITLS_Config *config = NULL;
switch (protocol) {
case APP_PROTOCOL_TLS:
config = HITLS_CFG_ProviderNewTLSConfig(APP_GetCurrent_LibCtx(), provider->providerAttr);
break;
case APP_PROTOCOL_TLCP:
config = HITLS_CFG_ProviderNewTLCPConfig(APP_GetCurrent_LibCtx(), provider->providerAttr);
break;
case APP_PROTOCOL_DTLCP:
config = HITLS_CFG_ProviderNewDTLCPConfig(APP_GetCurrent_LibCtx(), provider->providerAttr);
break;
default:
AppPrintError("Unsupported protocol type: %d\n", protocol);
return NULL;
}
if (config == NULL) {
AppPrintError("Failed to create protocol configuration\n");
}
if (protocol == APP_PROTOCOL_TLS) {
int32_t ret = HITLS_CFG_SetVersionForbid(config, TLCP_VERSION_BITS);
if (ret != HITLS_SUCCESS) {
HITLS_CFG_FreeConfig(config);
AppPrintError("Failed to disable TLCP for TLS protocol, errCode: 0x%x.\n", ret);
return NULL;
}
}
#ifdef HITLS_APP_SM_MODE
int32_t ret = HITLS_CFG_SetSessionTicketSupport(config, false);
if (ret != HITLS_SUCCESS) {
HITLS_CFG_FreeConfig(config);
AppPrintError("Failed to set session ticket support, errCode: 0x%x.\n", ret);
return NULL;
}
#endif
return config;
}
int ConfigureCipherSuites(HITLS_Config *config, const char *cipherStr, APP_ProtocolType protocol)
{
if (config == NULL || cipherStr == NULL) {
return HITLS_APP_INVALID_ARG;
}
int32_t ret;
uint32_t protocolVersion = 0;
bool needVersionCheck = false;
if (protocol == APP_PROTOCOL_DTLCP || protocol == APP_PROTOCOL_TLCP) {
protocolVersion = HITLS_VERSION_TLCP_DTLCP11;
needVersionCheck = true;
}
char *cipherStrCopy = BSL_SAL_Malloc(strlen(cipherStr) + 1);
if (cipherStrCopy == NULL) {
AppPrintError("Failed to allocate memory for cipher string\n");
return HITLS_APP_MEM_ALLOC_FAIL;
}
strcpy(cipherStrCopy, cipherStr);
uint32_t cipherCount = 1;
for (const char *p = cipherStr; *p != '\0'; p++) {
if (*p == ':') {
cipherCount++;
}
}
uint16_t *cipherSuites = BSL_SAL_Malloc(sizeof(uint16_t) * cipherCount);
if (cipherSuites == NULL) {
BSL_SAL_Free(cipherStrCopy);
AppPrintError("Failed to allocate memory for cipher suites\n");
return HITLS_APP_MEM_ALLOC_FAIL;
}
uint32_t index = 0;
char *nextTmp = NULL;
char *token = strtok_r(cipherStrCopy, ":", &nextTmp);
while (token != NULL && index < cipherCount) {
const HITLS_Cipher *cipher = HITLS_CFG_GetCipherSuiteByStdName((const uint8_t *)token);
if (cipher == NULL) {
AppPrintError("Invalid cipher suite: %s\n", token);
BSL_SAL_Free(cipherStrCopy);
BSL_SAL_Free(cipherSuites);
return HITLS_APP_ERR_SET_CIPHER;
}
if (needVersionCheck) {
if (protocolVersion < cipher->minVersion || protocolVersion > cipher->maxVersion) {
AppPrintError("Protocol (%d) not in cipher suite version range [%d, %d]!\n",
protocolVersion, cipher->minVersion, cipher->maxVersion);
BSL_SAL_Free(cipherStrCopy);
BSL_SAL_Free(cipherSuites);
return HITLS_APP_ERR_SET_CIPHER;
}
}
ret = HITLS_CFG_GetCipherSuite(cipher, &cipherSuites[index]);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to get cipher suite for %s: 0x%x\n", token, ret);
BSL_SAL_Free(cipherStrCopy);
BSL_SAL_Free(cipherSuites);
return HITLS_APP_ERR_SET_CIPHER;
}
index++;
token = strtok_r(NULL, ":", &nextTmp);
}
ret = HITLS_CFG_SetCipherSuites(config, cipherSuites, index);
BSL_SAL_Free(cipherStrCopy);
BSL_SAL_Free(cipherSuites);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to set cipher suites: 0x%x\n", ret);
return HITLS_APP_ERR_SET_CIPHER;
}
return HITLS_APP_SUCCESS;
}
typedef struct {
const char *name;
BSL_ParseFormat format;
} FormatMapEntry;
static const FormatMapEntry FORMAT_MAP[] = {
{"ASN1", BSL_FORMAT_ASN1},
{"PEM", BSL_FORMAT_PEM},
};
const char *GetFormatName(BSL_ParseFormat format)
{
for (size_t i = 0; i < sizeof(FORMAT_MAP)/sizeof(FORMAT_MAP[0]); ++i) {
if (FORMAT_MAP[i].format == format) {
return FORMAT_MAP[i].name;
}
}
return NULL;
}
HITLS_X509_Cert *LoadCertFromFile(const char *certFile, BSL_ParseFormat format, AppProvider *provider)
{
if (certFile == NULL) {
return NULL;
}
const char *formatName = GetFormatName(format);
uint8_t *data = NULL;
uint32_t dataLen = 0;
HITLS_X509_Cert *cert = NULL;
int32_t ret = BSL_SAL_ReadFile(certFile, &data, &dataLen);
if (ret != BSL_SUCCESS) {
return NULL;
}
BSL_Buffer encode = {data, dataLen};
ret = HITLS_X509_ProviderCertParseBuff(APP_GetCurrent_LibCtx(), provider->providerAttr, formatName, &encode, &cert);
if (ret != HITLS_SUCCESS) {
BSL_SAL_Free(data);
AppPrintError("Failed to load certificate from %s: 0x%x\n", certFile, ret);
return NULL;
}
BSL_SAL_Free(data);
return cert;
}
#ifdef HITLS_APP_SM_MODE
static int32_t GetPkeyCtxFromUuid(AppProvider *provider, HITLS_APP_SM_Param *smParam, char *uuid,
CRYPT_EAL_PkeyCtx **ctx)
{
HITLS_APP_KeyInfo keyInfo = {0};
HITLS_APP_SM_Param param = {0};
memcpy(¶m, smParam, sizeof(HITLS_APP_SM_Param));
param.uuid = uuid;
int32_t ret = HITLS_APP_FindKey(provider, ¶m, CRYPT_PKEY_SM2, &keyInfo);
if (ret != HITLS_APP_SUCCESS) {
AppPrintError("Failed to find key, errCode: 0x%x\n", ret);
return ret;
}
*ctx = keyInfo.pkeyCtx;
return HITLS_APP_SUCCESS;
}
#endif
#ifdef HITLS_APP_SM_MODE
static int32_t ReadEncKeyCipher(const char *cipherFile, uint8_t **cipher, uint32_t *cipherLen)
{
int32_t ret = BSL_SAL_ReadFile(cipherFile, cipher, cipherLen);
if (ret != BSL_SUCCESS) {
AppPrintError("Failed to read encrypted private key from %s\n", cipherFile);
}
return ret;
}
static int32_t DecryptEncKeyWithSign(CRYPT_EAL_PkeyCtx *signKey, const uint8_t *cipher, uint32_t cipherLen,
uint8_t **plain, uint32_t *plainLen)
{
uint8_t *buf = BSL_SAL_Malloc(cipherLen);
if (buf == NULL) {
AppPrintError("Failed to allocate memory for decrypted private key\n");
return HITLS_APP_MEM_ALLOC_FAIL;
}
uint32_t outLen = cipherLen;
int32_t ret = CRYPT_EAL_PkeyDecrypt(signKey, cipher, cipherLen, buf, &outLen);
if (ret != CRYPT_SUCCESS) {
AppPrintError("Failed to decrypt encrypted private key: 0x%x\n", ret);
BSL_SAL_ClearFree(buf, cipherLen);
return ret;
}
*plain = buf;
*plainLen = outLen;
return CRYPT_SUCCESS;
}
static CRYPT_EAL_PkeyCtx *CreateSm2PkeyFromPrv(AppProvider *provider, uint8_t *plain, uint32_t plainLen)
{
CRYPT_EAL_PkeyCtx *encKey = CRYPT_EAL_ProviderPkeyNewCtx(APP_GetCurrent_LibCtx(), CRYPT_PKEY_SM2, 0,
provider->providerAttr);
if (encKey == NULL) {
AppPrintError("Failed to create pkey context for decrypted private key\n");
return NULL;
}
BSL_Param prvParam[] = {{0}, BSL_PARAM_END};
(void)BSL_PARAM_InitValue(&prvParam[0], CRYPT_PARAM_EC_PRVKEY, BSL_PARAM_TYPE_OCTETS,
(void *)plain, plainLen);
int32_t ret = CRYPT_EAL_PkeySetPrvEx(encKey, prvParam);
if (ret != CRYPT_SUCCESS) {
AppPrintError("Failed to set decrypted private key: 0x%x\n", ret);
CRYPT_EAL_PkeyFreeCtx(encKey);
return NULL;
}
return encKey;
}
static CRYPT_EAL_PkeyCtx *LoadEncKeyBySignKey(APP_CertConfig *certConfig)
{
AppProvider *provider = certConfig->provider;
const char *cipherFile = certConfig->tlcpEncKey;
CRYPT_EAL_PkeyCtx *signKey = NULL;
CRYPT_EAL_PkeyCtx *encKey = NULL;
uint8_t *cipher = NULL;
uint8_t *plain = NULL;
uint32_t cipherLen = 0;
uint32_t plainLen = 0;
if (ReadEncKeyCipher(cipherFile, &cipher, &cipherLen) != BSL_SUCCESS) {
return NULL;
}
signKey = LoadKeyFromFile(certConfig, true);
if (signKey == NULL) {
AppPrintError("Failed to load TLCP signature private key for decrypt\n");
goto ERR;
}
if (DecryptEncKeyWithSign(signKey, cipher, cipherLen, &plain, &plainLen) != CRYPT_SUCCESS) {
goto ERR;
}
encKey = CreateSm2PkeyFromPrv(provider, plain, plainLen);
if (encKey == NULL) {
goto ERR;
}
CRYPT_EAL_PkeyFreeCtx(signKey);
BSL_SAL_Free(cipher);
BSL_SAL_ClearFree(plain, cipherLen);
return encKey;
ERR:
if (encKey != NULL) {
CRYPT_EAL_PkeyFreeCtx(encKey);
}
if (signKey != NULL) {
CRYPT_EAL_PkeyFreeCtx(signKey);
}
if (cipher != NULL) {
BSL_SAL_Free(cipher);
}
if (plain != NULL) {
BSL_SAL_ClearFree(plain, cipherLen);
}
return NULL;
}
#endif
static CRYPT_EAL_PkeyCtx *LoadKeyFromFile(APP_CertConfig *certConfig, bool isSignKey)
{
char *keyFile = isSignKey ? certConfig->tlcpSignKey : certConfig->tlcpEncKey;
BSL_ParseFormat format = certConfig->keyFormat;
const char *password = certConfig->keyPass;
AppProvider *provider = certConfig->provider;
if (keyFile == NULL) {
return NULL;
}
CRYPT_EAL_PkeyCtx *pkey = NULL;
#ifdef HITLS_APP_SM_MODE
if (isSignKey && certConfig->smParam->smTag == 1) {
int32_t ret = GetPkeyCtxFromUuid(provider, certConfig->smParam, keyFile, &pkey);
if (ret == HITLS_APP_SUCCESS) {
return pkey;
}
}
if (!isSignKey && certConfig->smParam->smTag == 1) {
pkey = LoadEncKeyBySignKey(certConfig);
if (pkey != NULL) {
return pkey;
}
}
#endif
char *pass = NULL;
if (password != NULL) {
size_t len = strlen(password) + 1;
pass = BSL_SAL_Malloc(len);
if (pass != NULL) {
strcpy(pass, password);
}
}
pkey = HITLS_APP_ProviderLoadPrvKey(APP_GetCurrent_LibCtx(), provider->providerAttr, keyFile, format, &pass);
if (pkey == NULL) {
AppPrintError("Failed to load private key from %s\n", keyFile);
}
if (pass != NULL) {
BSL_SAL_ClearFree(pass, strlen(pass));
}
return pkey;
}
int ConfCertVerification(HITLS_Config *config, APP_CertConfig *certConfig,
bool verifyPeer, int verifyDepth)
{
if (config == NULL) {
return HITLS_APP_INVALID_ARG;
}
int ret = HITLS_SUCCESS;
bool hasLoadedCA = false;
if (certConfig && certConfig->caFile) {
HITLS_X509_Cert *ca_cert = LoadCertFromFile(certConfig->caFile, certConfig->certFormat, certConfig->provider);
if (ca_cert == NULL) {
AppPrintError("Failed to load CA certificate from %s\n", certConfig->caFile);
return HITLS_APP_ERR_LOAD_CA;
}
ret = HITLS_CFG_AddCertToStore(config, ca_cert, TLS_CERT_STORE_TYPE_DEFAULT, true);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to add CA certificate to store: 0x%x\n", ret);
HITLS_X509_CertFree(ca_cert);
return HITLS_APP_ERR_LOAD_CA;
}
HITLS_X509_CertFree(ca_cert);
hasLoadedCA = true;
}
if (certConfig && certConfig->caChain) {
HITLS_X509_List *certlist = NULL;
ret = HITLS_X509_CertParseBundleFile(certConfig->certFormat, certConfig->caChain, &certlist);
if (ret != BSL_SUCCESS) {
AppPrintError("Failed to parse certificate <%s>, errCode = %d.\n", certConfig->caChain, ret);
return HITLS_APP_X509_FAIL;
}
for (BslListNode *node = BSL_LIST_FirstNode(certlist); node != NULL;
node = BSL_LIST_GetNextNode(certlist, node)) {
HITLS_X509_Cert *cert = BSL_LIST_GetData(node);
ret = HITLS_CFG_AddCertToStore(config, cert, TLS_CERT_STORE_TYPE_DEFAULT, true);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to add CA-chain certificate to store: 0x%x\n", ret);
ret = HITLS_APP_ERR_LOAD_CA;
break;
}
}
BSL_LIST_FREE(certlist, (BSL_LIST_PFUNC_FREE)HITLS_X509_CertFree);
if (ret != HITLS_SUCCESS) {
return ret;
}
hasLoadedCA = true;
}
if (!hasLoadedCA) {
ret = HITLS_CFG_LoadDefaultCAPath(config);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to load default CA path: 0x%x\n", ret);
}
}
ret = HITLS_CFG_SetVerifyNoneSupport(config, !verifyPeer);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to disable server verification: 0x%x\n", ret);
return HITLS_APP_ERR_SET_VERIFY;
}
ret = HITLS_CFG_SetClientVerifySupport(config, verifyPeer);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to set client verification: 0x%x\n", ret);
return HITLS_APP_ERR_SET_VERIFY;
}
if (verifyPeer) {
ret = HITLS_CFG_SetNoClientCertSupport(config, false);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to require client certificate: 0x%x\n", ret);
return HITLS_APP_ERR_SET_VERIFY;
}
}
if (verifyDepth > 0) {
ret = HITLS_CFG_SetVerifyDepth(config, verifyDepth);
if (ret != HITLS_SUCCESS) {
AppPrintError("Failed to set verification depth: 0x%x\n", ret);
return HITLS_APP_ERR_SET_VERIFY;
}
}
return HITLS_APP_SUCCESS;
}
int ConfigureTLCPCertificates(HITLS_Config *config, APP_CertConfig *certConfig)
{
if (config == NULL || certConfig == NULL) {
return HITLS_APP_INVALID_ARG;
}
int ret = HITLS_SUCCESS;
if (certConfig->tlcpSignCert && certConfig->tlcpSignKey) {
HITLS_X509_Cert *sign_cert = LoadCertFromFile(certConfig->tlcpSignCert, certConfig->certFormat,
certConfig->provider);
CRYPT_EAL_PkeyCtx *sign_key = LoadKeyFromFile(certConfig, true);
if (sign_cert && sign_key) {
ret = HITLS_CFG_SetTlcpCertificate(config, sign_cert, false, false);
if (ret != HITLS_SUCCESS) {
HITLS_X509_CertFree(sign_cert);
CRYPT_EAL_PkeyFreeCtx(sign_key);
AppPrintError("Failed to set TLCP signature certificate: 0x%x\n", ret);
return HITLS_APP_ERR_SET_TLCP_CERT;
}
ret = HITLS_CFG_SetTlcpPrivateKey(config, sign_key, false, false);
if (ret != HITLS_SUCCESS) {
CRYPT_EAL_PkeyFreeCtx(sign_key);
AppPrintError("Failed to set TLCP signature private key: 0x%x\n", ret);
return HITLS_APP_ERR_SET_TLCP_CERT;
}
} else {
HITLS_X509_CertFree(sign_cert);
CRYPT_EAL_PkeyFreeCtx(sign_key);
return HITLS_APP_ERR_SET_TLCP_CERT;
}
}
if (certConfig->tlcpEncCert && certConfig->tlcpEncKey) {
HITLS_X509_Cert *enc_cert = LoadCertFromFile(certConfig->tlcpEncCert, certConfig->certFormat,
certConfig->provider);
CRYPT_EAL_PkeyCtx *enc_key = LoadKeyFromFile(certConfig, false);
if (enc_cert && enc_key) {
ret = HITLS_CFG_SetTlcpCertificate(config, enc_cert, false, true);
if (ret != HITLS_SUCCESS) {
HITLS_X509_CertFree(enc_cert);
CRYPT_EAL_PkeyFreeCtx(enc_key);
AppPrintError("Failed to set TLCP encryption certificate: 0x%x\n", ret);
return HITLS_APP_ERR_SET_TLCP_CERT;
}
ret = HITLS_CFG_SetTlcpPrivateKey(config, enc_key, false, true);
if (ret != HITLS_SUCCESS) {
CRYPT_EAL_PkeyFreeCtx(enc_key);
AppPrintError("Failed to set TLCP encryption private key: 0x%x\n", ret);
return HITLS_APP_ERR_SET_TLCP_CERT;
}
} else {
HITLS_X509_CertFree(enc_cert);
CRYPT_EAL_PkeyFreeCtx(enc_key);
return HITLS_APP_ERR_SET_TLCP_CERT;
}
}
return HITLS_APP_SUCCESS;
}
int CreateTCPSocket(APP_NetworkAddr *addr, int timeout)
{
if (addr == NULL || addr->host == NULL) {
return -1;
}
int sockfd = BSL_SAL_Socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
AppPrintError("Failed to create socket: %s\n", strerror(errno));
return -1;
}
if (timeout > 0) {
struct timeval tv;
tv.tv_sec = timeout;
tv.tv_usec = 0;
BSL_SAL_SetSockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
BSL_SAL_SetSockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv));
}
struct sockaddr_in serverAdd;
memset(&serverAdd, 0, sizeof(serverAdd));
serverAdd.sin_family = AF_INET;
serverAdd.sin_port = htons(addr->port);
if (inet_pton(AF_INET, addr->host, &serverAdd.sin_addr) <= 0) {
struct hostent *hostEntry = gethostbyname(addr->host);
if (hostEntry == NULL) {
AppPrintError("Failed to resolve hostname: %s\n", addr->host);
BSL_SAL_SockClose(sockfd);
return -1;
}
memcpy(&serverAdd.sin_addr, hostEntry->h_addr_list[0], (size_t)hostEntry->h_length);
}
if (BSL_SAL_SockConnect(sockfd, (BSL_SAL_SockAddr)&serverAdd, sizeof(serverAdd)) < 0) {
AppPrintError("Failed to connect to %s:%d: %s\n", addr->host, addr->port, strerror(errno));
BSL_SAL_SockClose(sockfd);
return -1;
}
return sockfd;
}
int CreateUDPSocket(APP_NetworkAddr *addr, int timeout)
{
(void)timeout;
if (addr == NULL || addr->host == NULL) {
return -1;
}
int sockfd = BSL_SAL_Socket(AF_INET, SOCK_DGRAM, 0);
if (sockfd < 0) {
AppPrintError("Failed to create UDP socket: %s\n", strerror(errno));
return -1;
}
struct sockaddr_in serverAdd;
memset(&serverAdd, 0, sizeof(serverAdd));
serverAdd.sin_family = AF_INET;
serverAdd.sin_port = htons(addr->port);
if (inet_pton(AF_INET, addr->host, &serverAdd.sin_addr) <= 0) {
struct hostent *hostEntry = gethostbyname(addr->host);
if (hostEntry == NULL) {
AppPrintError("Failed to resolve hostname: %s\n", addr->host);
BSL_SAL_SockClose(sockfd);
return -1;
}
memcpy(&serverAdd.sin_addr, hostEntry->h_addr_list[0], (size_t)hostEntry->h_length);
}
if (BSL_SAL_SockConnect(sockfd, (BSL_SAL_SockAddr)&serverAdd, sizeof(serverAdd)) < 0) {
AppPrintError("Failed to connect UDP socket to %s:%d: %s\n", addr->host, addr->port, strerror(errno));
BSL_SAL_SockClose(sockfd);
return -1;
}
return sockfd;
}
int CreateTCPListenSocket(APP_NetworkAddr *addr, int backlog)
{
int sockfd = BSL_SAL_Socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
AppPrintError("Failed to create listen socket: %s\n", strerror(errno));
return -1;
}
int opt = 1;
BSL_SAL_SetSockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
struct sockaddr_in bindAddr;
memset(&bindAddr, 0, sizeof(bindAddr));
bindAddr.sin_family = AF_INET;
bindAddr.sin_port = htons(addr->port);
if (addr->host && strcmp(addr->host, "0.0.0.0") != 0) {
if (inet_pton(AF_INET, addr->host, &bindAddr.sin_addr) <= 0) {
AppPrintError("Invalid bind address: %s\n", addr->host);
BSL_SAL_SockClose(sockfd);
return -1;
}
} else {
bindAddr.sin_addr.s_addr = INADDR_ANY;
}
if (BSL_SAL_SockBind(sockfd, (BSL_SAL_SockAddr)&bindAddr, sizeof(bindAddr)) < 0) {
AppPrintError("Failed to bind to %s:%d: %s\n",
addr->host ? addr->host : "0.0.0.0", addr->port, strerror(errno));
BSL_SAL_SockClose(sockfd);
return -1;
}
if (BSL_SAL_SockListen(sockfd, backlog) < 0) {
AppPrintError("Failed to listen: %s\n", strerror(errno));
BSL_SAL_SockClose(sockfd);
return -1;
}
return sockfd;
}
int CreateUDPListenSocket(APP_NetworkAddr *addr, int timeout)
{
int sockfd = BSL_SAL_Socket(AF_INET, SOCK_DGRAM, 0);
if (sockfd < 0) {
AppPrintError("Failed to create UDP listen socket: %s\n", strerror(errno));
return -1;
}
if (timeout > 0) {
struct timeval tv;
tv.tv_sec = timeout;
tv.tv_usec = 0;
BSL_SAL_SetSockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
BSL_SAL_SetSockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv));
}
struct sockaddr_in bindAddr;
memset(&bindAddr, 0, sizeof(bindAddr));
bindAddr.sin_family = AF_INET;
bindAddr.sin_port = htons(addr->port);
if (addr->host && strcmp(addr->host, "0.0.0.0") != 0) {
if (inet_pton(AF_INET, addr->host, &bindAddr.sin_addr) <= 0) {
AppPrintError("Invalid bind address: %s\n", addr->host);
BSL_SAL_SockClose(sockfd);
return -1;
}
} else {
bindAddr.sin_addr.s_addr = INADDR_ANY;
}
if (BSL_SAL_SockBind(sockfd, (BSL_SAL_SockAddr)&bindAddr, sizeof(bindAddr)) < 0) {
AppPrintError("Failed to bind UDP to %s:%d: %s\n",
addr->host ? addr->host : "0.0.0.0", addr->port, strerror(errno));
BSL_SAL_SockClose(sockfd);
return -1;
}
return sockfd;
}
int AcceptTCPConnection(int listenFd)
{
struct sockaddr_in clientAddr;
socklen_t addrLen = sizeof(clientAddr);
int flags = fcntl(listenFd, F_GETFL, 0);
fcntl(listenFd, F_SETFL, flags | O_NONBLOCK);
int clientFd = accept(listenFd, (struct sockaddr *)&clientAddr, &addrLen);
if (clientFd < 0) {
return -1;
}
char clientIp[INET_ADDRSTRLEN];
inet_ntop(AF_INET, &clientAddr.sin_addr, clientIp, INET_ADDRSTRLEN);
AppPrintInfo("Accepted connection from %s:%d\n", clientIp, ntohs(clientAddr.sin_port));
return clientFd;
}
void PrintConnectionInfo(HITLS_Ctx *ctx, bool showState)
{
if (ctx == NULL) {
return;
}
uint16_t version;
if (HITLS_GetNegotiatedVersion(ctx, &version) == HITLS_SUCCESS) {
AppPrintInfo("Protocol version: ");
switch (version) {
case HITLS_VERSION_TLS12:
AppPrintInfo("TLSv1.2\n");
break;
case HITLS_VERSION_TLS13:
AppPrintInfo("TLSv1.3\n");
break;
case HITLS_VERSION_DTLS12:
AppPrintInfo("DTLSv1.2\n");
break;
case HITLS_VERSION_TLCP_DTLCP11:
AppPrintInfo("TLCP v1.1\n");
break;
default:
AppPrintInfo("Unknown (0x%04x)\n", version);
break;
}
}
const HITLS_Cipher *cipher = HITLS_GetCurrentCipher(ctx);
if (cipher != NULL) {
AppPrintInfo("Cipher suite negotiated\n");
}
if (showState) {
PrintHandshakeState(ctx);
}
}
void PrintHandshakeState(HITLS_Ctx *ctx)
{
if (ctx == NULL) {
return;
}
uint32_t state;
if (HITLS_GetHandShakeState(ctx, &state) == HITLS_SUCCESS) {
const char *stateStr = HITLS_GetStateString(state);
AppPrintInfo("Handshake state: %s\n", stateStr ? stateStr : "Unknown");
}
}
int ParseConnectString(const char *connectStr, APP_NetworkAddr *addr)
{
if (connectStr == NULL || addr == NULL) {
return HITLS_APP_INVALID_ARG;
}
size_t len = strlen(connectStr) + 1;
char *strCopy = BSL_SAL_Malloc(len);
if (strCopy == NULL) {
AppPrintError("Failed to alloc memory.\n");
return HITLS_APP_MEM_ALLOC_FAIL;
}
strcpy(strCopy, connectStr);
char *colon_pos = strrchr(strCopy, ':');
if (colon_pos == NULL) {
addr->host = strCopy;
addr->port = 443;
return HITLS_APP_SUCCESS;
}
*colon_pos = '\0';
size_t host_len = strlen(strCopy) + 1;
addr->host = BSL_SAL_Malloc(host_len);
if (addr->host == NULL) {
BSL_SAL_Free(strCopy);
AppPrintError("Failed to alloc memory.");
return HITLS_APP_MEM_ALLOC_FAIL;
}
(void)strcpy(addr->host, strCopy);
addr->port = atoi(colon_pos + 1);
BSL_SAL_Free(strCopy);
if (addr->port <= 0 || addr->port > 65535) {
BSL_SAL_Free(addr->host);
addr->host = NULL;
AppPrintError("Invalid port number in connect string\n");
return HITLS_APP_INVALID_ARG;
}
return HITLS_APP_SUCCESS;
}
#ifdef HITLS_APP_SM_MODE
int32_t GetHeartBeat(uint8_t *buffer, uint32_t *len)
{
if (buffer == NULL || len == NULL || *len < APP_HEARTBEAT_LEN) {
AppPrintError("Invalid buffer or length.\n");
return HITLS_APP_INVALID_ARG;
}
int64_t time = 0;
int ret = HITLS_APP_GetTime(&time);
if (ret != HITLS_APP_SUCCESS) {
AppPrintError("Failed to get time, errCode: 0x%x.\n", ret);
return ret;
}
BSL_Uint64ToByte(time, (uint8_t *)&time);
const char *heartBeat = HEARTBEAT_STR;
memcpy(buffer, heartBeat, strlen(heartBeat));
memcpy(buffer + strlen(heartBeat), &time, sizeof(time));
*len = APP_HEARTBEAT_LEN;
return HITLS_APP_SUCCESS;
}
int32_t ParseHeartBeat(uint8_t *buffer, uint32_t len)
{
if (buffer == NULL || len != APP_HEARTBEAT_LEN) {
AppPrintError("Invalid buffer or length.\n");
return HITLS_APP_INVALID_ARG;
}
int ret = strncmp((const char *)buffer, HEARTBEAT_STR, strlen(HEARTBEAT_STR));
if (ret != 0) {
AppPrintError("Invalid heartbeat string.\n");
return HITLS_APP_INVALID_ARG;
}
int64_t time = 0;
memcpy(&time, buffer + strlen(HEARTBEAT_STR), sizeof(time));
time = BSL_ByteToUint64((uint8_t *)&time);
return HITLS_APP_SUCCESS;
}
#endif