#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <pthread.h>
#include <semaphore.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <time.h>
#include <unistd.h>
#include <nuttx/config.h>
#ifdef CONFIG_CRYPTO_MBEDTLS
#include <mbedtls/ctr_drbg.h>
#ifdef CONFIG_MBEDTLS_DEBUG_C
#include <mbedtls/debug.h>
#endif
#include <mbedtls/entropy.h>
#include <mbedtls/net_sockets.h>
#include <mbedtls/ssl.h>
#include <mbedtls/x509_crt.h>
#ifdef CONFIG_MBEDTLS_DEBUG_C
/* mbedTLS debug hook registered by tls_connect() when CONFIG_MBEDTLS_DEBUG_C
 * is enabled.  It deliberately emits nothing; each argument is consumed only
 * so the compiler does not warn about unused parameters.
 */
static void my_debug(void *ctx, int level, const char *file, int line,
                     const char *str)
{
  (void)str;
  (void)line;
  (void)file;
  (void)level;
  (void)ctx;
}
#endif
/* GlobalSign Root CA - R3 certificate, PEM encoded.
 * Valid: 2009-03-18 to 2029-03-18
 * Required by Alibaba Bailian/DashScope TLS connection
 *
 * tls_connect() parses this with sizeof(globalsign_root_r3_pem), so the
 * length handed to mbedtls_x509_crt_parse() includes the terminating NUL,
 * which mbedTLS expects for PEM input.
 * NOTE(review): this root expires 2029-03-18; the bundle must be refreshed
 * (or another root added) before then.
 */
static const char globalsign_root_r3_pem[] =
"-----BEGIN CERTIFICATE-----\r\n"
"MIIDXzCCAkegAwIBAgILBAAAAAABIVhTCKIwDQYJKoZIhvcNAQELBQAwTDEgMB4G\r\n"
"A1UECxMXR2xvYmFsU2lnbiBSb290IENBIC0gUjMxEzARBgNVBAoTCkdsb2JhbFNp\r\n"
"Z24xEzARBgNVBAMTCkdsb2JhbFNpZ24wHhcNMDkwMzE4MTAwMDAwWhcNMjkwMzE4\r\n"
"MTAwMDAwWjBMMSAwHgYDVQQLExdHbG9iYWxTaWduIFJvb3QgQ0EgLSBSMzETMBEG\r\n"
"A1UEChMKR2xvYmFsU2lnbjETMBEGA1UEAxMKR2xvYmFsU2lnbjCCASIwDQYJKoZI\r\n"
"hvcNAQEBBQADggEPADCCAQoCggEBAMwldpB5BngiFvXAg7aEyiie/QV2EcWtiHL8\r\n"
"RgJDx7KKnQRfJMsuS+FggkbhUqsMgUdwbN1k0ev1LKMPgj0MK66X17YUhhB5uzsT\r\n"
"gHeMCOFJ0mpiLx9e+pZo34knlTifBtc+ycsmWQ1z3rDI6SYOgxXG71uL0gRgykmm\r\n"
"KPZpO/bLyCiR5Z2KYVc3rHQU3HTgOu5yLy6c+9C7v/U9AOEGM+iCK65TpjoWc4zd\r\n"
"QQ4gOsC0p6Hpsk+QLjJg6VfLuQSSaGjlOCZgdbKfd/+RFO+uIEn8rUAVSNECMWEZ\r\n"
"XriX7613t2Saer9fwRPvm2L7DWzgVGkWqQPabumDk3F2xmmFghcCAwEAAaNCMEAw\r\n"
"DgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFI/wS3+o\r\n"
"LkUkrk1Q+mOai97i3Ru8MA0GCSqGSIb3DQEBCwUAA4IBAQBLQNvAUKr+yAzv95ZU\r\n"
"RUm7lgAJQayzE4aGKAczymvmdLm6AC2upArT9fHxD4q/c2dKg8dEe3jgr25sbwMp\r\n"
"jjM5RcOO5LlXbKr8EpbsU8Yt5CRsuZRj+9xTaGdWPoO4zzUhw8lo/s7awlOqzJCK\r\n"
"6fBdRoyV3XpYKBovHd7NADdBj+1EbddTKJd+82cEHhXXipa0095MJ6RMG3NzdvQX\r\n"
"mcIfeg7jLQitChws/zyrVQ4PkX4268NXSb7hLi18YIvDQVETI53O9zJrlAGomecs\r\n"
"Mx86OyXShkDOOyyGeMlhLxS67ttVb9+E7gUJTb0o2HLO02JQZR7rkpeDMdmztcpH\r\n"
"WD9f\r\n"
"-----END CERTIFICATE-----\r\n";
#endif
/* Fixed GUID appended to the Sec-WebSocket-Key before hashing to obtain the
 * expected Sec-WebSocket-Accept value (RFC 6455); used by ws_build_accept().
 */
#define WS_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
/* Components of a ws:// or wss:// URL, filled in by parse_ws_url(). */
struct ws_url_s
{
char host[128];   /* Host name only (no port), NUL-terminated */
char path[256];   /* Request path; "/" when the URL has none */
uint16_t port;    /* Explicit port, otherwise 80 (ws) or 443 (wss) */
bool tls;         /* true for wss:// */
};
/* Per-connection mbedTLS state, set up by tls_connect() and released by
 * tls_close().  The mbedTLS members only exist with CONFIG_CRYPTO_MBEDTLS.
 */
struct ws_tls_s
{
#ifdef CONFIG_CRYPTO_MBEDTLS
mbedtls_net_context net;            /* Holds the TCP fd passed to the BIO callbacks */
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_entropy_context entropy;
mbedtls_x509_crt cacert;            /* Parsed globalsign_root_r3_pem trust anchor */
#endif
bool inited;                        /* Set after a successful handshake; tls_close() does nothing while false */
};
/* WebSocket client handle created by bailian_ws_connect() and released by
 * bailian_ws_close().
 */
struct bailian_ws_client
{
int fd;                     /* Plain-TCP socket; the TLS path keeps its socket in tls_ctx.net instead */
bool tls;                   /* true when the URL was wss:// */
struct ws_tls_s tls_ctx;    /* TLS session state, meaningful only when tls is true */
};
/* Minimal SHA-1 context; in this file it is used to compute the
 * Sec-WebSocket-Accept value in ws_build_accept().
 */
typedef struct
{
uint32_t state[5];    /* Chaining words H0..H4 */
uint64_t count;       /* Total input length absorbed so far, in bits */
uint8_t buffer[64];   /* Partially filled input block */
} sha1_ctx_t;
/* SHA-1 compression function: mixes one 64-byte input block into the five
 * 32-bit chaining words held in state (updated in place).
 */
static void sha1_transform(uint32_t state[5], const uint8_t buffer[64])
{
  static const uint32_t round_k[4] =
  {
    0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6
  };

  uint32_t sched[80];
  uint32_t va = state[0];
  uint32_t vb = state[1];
  uint32_t vc = state[2];
  uint32_t vd = state[3];
  uint32_t ve = state[4];
  int n;

  /* Message schedule: 16 big-endian words taken from the block, followed
   * by 64 derived words (XOR of four earlier words, rotated left by one).
   */
  for (n = 0; n < 80; n++)
    {
      if (n < 16)
        {
          const uint8_t *p = &buffer[n * 4];

          sched[n] = ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
                     ((uint32_t)p[2] << 8) | (uint32_t)p[3];
        }
      else
        {
          uint32_t x = sched[n - 3] ^ sched[n - 8] ^
                       sched[n - 14] ^ sched[n - 16];

          sched[n] = (x << 1) | (x >> 31);
        }
    }

  /* 80 rounds in four stages of 20, each with its own function/constant. */
  for (n = 0; n < 80; n++)
    {
      int stage = n / 20;
      uint32_t fn;
      uint32_t tmp;

      switch (stage)
        {
          case 0:   /* Choose */
            fn = (vb & vc) | ((~vb) & vd);
            break;

          case 2:   /* Majority */
            fn = (vb & vc) | (vb & vd) | (vc & vd);
            break;

          default:  /* Stages 1 and 3: parity */
            fn = vb ^ vc ^ vd;
            break;
        }

      tmp = ((va << 5) | (va >> 27)) + fn + ve + round_k[stage] + sched[n];
      ve = vd;
      vd = vc;
      vc = (vb << 30) | (vb >> 2);
      vb = va;
      va = tmp;
    }

  state[0] += va;
  state[1] += vb;
  state[2] += vc;
  state[3] += vd;
  state[4] += ve;
}
/* Reset ctx to the SHA-1 initial chaining value with no input absorbed.
 * The block buffer is not cleared; sha1_update() fills it before use.
 */
static void sha1_init(sha1_ctx_t *ctx)
{
  static const uint32_t initial[5] =
  {
    0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
  };

  for (int k = 0; k < 5; k++)
    {
      ctx->state[k] = initial[k];
    }

  ctx->count = 0;
}
/* Absorb len bytes into the running hash.  ctx->count (in bits) is advanced
 * up front; bytes are staged in ctx->buffer and every time 64 bytes have
 * accumulated the block is compressed with sha1_transform().
 */
static void sha1_update(sha1_ctx_t *ctx, const uint8_t *data, size_t len)
{
  size_t fill = (size_t)((ctx->count >> 3) & 0x3f);

  ctx->count += (uint64_t)len << 3;

  while (len > 0)
    {
      ctx->buffer[fill++] = *data++;
      len--;

      if (fill == sizeof(ctx->buffer))
        {
          sha1_transform(ctx->state, ctx->buffer);
          fill = 0;
        }
    }
}
/* Finish the hash by appending 0x80, zero padding up to 56 mod 64, and the
 * 64-bit big-endian message length in bits, then write the five chaining
 * words to digest as 20 big-endian bytes.  The context is consumed by this.
 */
static void sha1_final(sha1_ctx_t *ctx, uint8_t digest[20])
{
  uint8_t padding[64] = { 0x80 };
  uint8_t bitlen[8];
  uint64_t bits = ctx->count;
  size_t used = (size_t)((bits >> 3) & 0x3f);
  size_t padlen = (used < 56) ? (56 - used) : (120 - used);
  int k;

  /* Serialise the bit count big-endian before padding alters ctx->count. */
  for (k = 7; k >= 0; k--)
    {
      bitlen[k] = (uint8_t)bits;
      bits >>= 8;
    }

  sha1_update(ctx, padding, padlen);
  sha1_update(ctx, bitlen, sizeof(bitlen));

  for (k = 0; k < 20; k++)
    {
      digest[k] = (uint8_t)(ctx->state[k / 4] >> (24 - 8 * (k % 4)));
    }
}
/* Base64-encode len bytes of src into dst using the standard alphabet with
 * '=' padding.  A 4-character group is only emitted while it and a trailing
 * NUL still fit in dst_len; input that does not fit is silently dropped.
 * dst is NUL-terminated whenever dst_len > 0.
 */
static void base64_encode(const uint8_t *src, size_t len, char *dst,
                          size_t dst_len)
{
  static const char alphabet[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
  size_t in = 0;
  size_t out = 0;

  while (in < len && out + 4 < dst_len)
    {
      size_t avail = len - in;
      uint32_t group = (uint32_t)src[in] << 16;

      if (avail > 1)
        {
          group |= (uint32_t)src[in + 1] << 8;
        }

      if (avail > 2)
        {
          group |= (uint32_t)src[in + 2];
        }

      dst[out++] = alphabet[(group >> 18) & 0x3f];
      dst[out++] = alphabet[(group >> 12) & 0x3f];
      dst[out++] = (avail > 1) ? alphabet[(group >> 6) & 0x3f] : '=';
      dst[out++] = (avail > 2) ? alphabet[group & 0x3f] : '=';

      in += 3;
    }

  if (out < dst_len)
    {
      dst[out] = '\0';
    }
}
/* Split a "ws://host[:port][/path]" or "wss://host[:port][/path]" URL.
 *
 * The port defaults to 80 (ws) or 443 (wss).  An explicit port must be a
 * non-empty run of decimal digits with a value in 1..65535; anything else
 * (empty, non-numeric, zero, or too large) is rejected rather than being
 * silently truncated.  The path defaults to "/".  Bracketed IPv6 literals
 * are not supported.
 *
 * Returns 0 on success, or -EINVAL for a NULL argument, an unknown scheme,
 * an empty or overlong host, an invalid port, or an overlong path.  out is
 * zeroed first and may be partially filled when an error is returned.
 */
static int parse_ws_url(const char *url, struct ws_url_s *out)
{
  const char *host_start;
  const char *path_start;
  const char *port_start;
  size_t host_len;
  size_t path_len;

  if (url == NULL || out == NULL)
    {
      return -EINVAL;
    }

  memset(out, 0, sizeof(*out));

  if (strncmp(url, "ws://", 5) == 0)
    {
      out->port = 80;
      out->tls = false;
      host_start = url + 5;
    }
  else if (strncmp(url, "wss://", 6) == 0)
    {
      out->port = 443;
      out->tls = true;
      host_start = url + 6;
    }
  else
    {
      return -EINVAL;
    }

  path_start = strchr(host_start, '/');
  if (path_start == NULL)
    {
      path_start = host_start + strlen(host_start);
    }

  port_start = memchr(host_start, ':', (size_t)(path_start - host_start));
  if (port_start != NULL)
    {
      const char *p;
      unsigned long port = 0;

      host_len = (size_t)(port_start - host_start);

      /* The port text runs from just after ':' up to the path. */
      if (port_start + 1 == path_start)
        {
          return -EINVAL;
        }

      for (p = port_start + 1; p < path_start; p++)
        {
          if (*p < '0' || *p > '9')
            {
              return -EINVAL;
            }

          port = port * 10 + (unsigned long)(*p - '0');
          if (port > 65535)
            {
              return -EINVAL;
            }
        }

      if (port == 0)
        {
          return -EINVAL;
        }

      out->port = (uint16_t)port;
    }
  else
    {
      host_len = (size_t)(path_start - host_start);
    }

  if (host_len == 0 || host_len >= sizeof(out->host))
    {
      return -EINVAL;
    }

  memcpy(out->host, host_start, host_len);
  out->host[host_len] = '\0';

  if (*path_start == '\0')
    {
      strcpy(out->path, "/");
    }
  else
    {
      path_len = strlen(path_start);
      if (path_len >= sizeof(out->path))
        {
          return -EINVAL;
        }

      memcpy(out->path, path_start, path_len + 1);
    }

  return 0;
}
/* Open a blocking TCP connection to host:port, trying each address returned
 * by getaddrinfo() in order until one connects.
 *
 * When timeout_ms > 0 it is installed as SO_RCVTIMEO/SO_SNDTIMEO on each
 * socket before connect(); otherwise the socket is left without I/O
 * timeouts.
 *
 * Returns the connected fd, -EHOSTUNREACH if name resolution fails, or the
 * negated errno of the last socket()/connect() failure.
 */
static int tcp_connect(const char *host, uint16_t port, int timeout_ms)
{
  struct addrinfo hints;
  struct addrinfo *res = NULL;
  struct addrinfo *it;
  char port_str[8];
  int last_err = EHOSTUNREACH;
  int fd = -1;

  snprintf(port_str, sizeof(port_str), "%u", (unsigned int)port);

  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  if (getaddrinfo(host, port_str, &hints, &res) != 0)
    {
      return -EHOSTUNREACH;
    }

  for (it = res; it != NULL; it = it->ai_next)
    {
      fd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
      if (fd < 0)
        {
          last_err = errno;
          continue;
        }

      /* A zero or negative timeout would mean "none" or be invalid. */
      if (timeout_ms > 0)
        {
          struct timeval tv;

          tv.tv_sec = timeout_ms / 1000;
          tv.tv_usec = (timeout_ms % 1000) * 1000;
          setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
          setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv));
        }

      if (connect(fd, it->ai_addr, it->ai_addrlen) == 0)
        {
          break;
        }

      last_err = errno;
      close(fd);
      fd = -1;
    }

  freeaddrinfo(res);

  if (fd < 0)
    {
      /* Report a real negative errno instead of a bare -1 (-EPERM). */
      return (last_err > 0) ? -last_err : -EHOSTUNREACH;
    }

  return fd;
}
#ifdef CONFIG_CRYPTO_MBEDTLS
/* Maximum number of IPv4 addresses kept by bailian_ws_dns_prefetch(). */
#define WS_DNS_MAX_ADDRS 8
/* IPv4 addresses resolved by bailian_ws_dns_prefetch() and tried in order
 * by tls_tcp_connect_ipv4(), which overwrites sin_port for each attempt.
 * NOTE(review): plain file-scope state with no locking and no record of
 * which host was resolved; the TLS path connects to these addresses
 * whatever host it is asked for.
 */
static struct sockaddr_in g_ws_cached_addrs[WS_DNS_MAX_ADDRS];
static int g_ws_cached_count = 0;
/* Resolve host (IPv4 only) ahead of time and cache up to WS_DNS_MAX_ADDRS
 * addresses for the TLS connect path.  On success the cache is replaced;
 * on failure the getaddrinfo() result code is logged and the previous cache
 * contents are left as they were.
 */
void bailian_ws_dns_prefetch(const char *host, uint16_t port)
{
  struct addrinfo hints;
  struct addrinfo *list = NULL;
  struct addrinfo *ai;
  char service[8];
  int rc;

  snprintf(service, sizeof(service), "%u", port);

  memset(&hints, 0, sizeof(hints));
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_family = AF_INET;

  rc = getaddrinfo(host, service, &hints, &list);
  if (rc != 0 || list == NULL)
    {
      printf("[DNS] Pre-resolve failed: %d\n", rc);
      fflush(stdout);
      return;
    }

  g_ws_cached_count = 0;

  for (ai = list; ai != NULL; ai = ai->ai_next)
    {
      if (g_ws_cached_count >= WS_DNS_MAX_ADDRS)
        {
          break;
        }

      if (ai->ai_family != AF_INET)
        {
          continue;
        }

      memcpy(&g_ws_cached_addrs[g_ws_cached_count], ai->ai_addr,
             sizeof(struct sockaddr_in));
      g_ws_cached_count++;
    }

  freeaddrinfo(list);
}
* NuttX connect() blocks even with O_NONBLOCK/SOCK_NONBLOCK set
* (FDCHECK encoded fds break non-blocking connect).
* Use a helper thread to do blocking connect, with polling
* sem_trywait for timeout (sem_timedwait may also block on NuttX).
*/
struct connect_helper_s
{
int fd;
struct sockaddr_in addr;
volatile int result;
volatile int err;
volatile bool done;
};
static void *connect_thread_func(void *arg)
{
struct connect_helper_s *h = (struct connect_helper_s *)arg;
h->result = connect(h->fd, (struct sockaddr *)&h->addr, sizeof(h->addr));
h->err = errno;
h->done = true;
return NULL;
}
* Uses a helper thread + polling for connect timeout,
* because NuttX connect() ignores O_NONBLOCK with FDCHECK.
*/
static int tls_tcp_connect_ipv4(const char *host, uint16_t port,
int timeout_ms)
{
int fd;
int i;
if (g_ws_cached_count <= 0)
{
printf("[TLS] No cached DNS for %s\n", host);
return -EHOSTUNREACH;
}
for (i = 0; i < g_ws_cached_count; i++)
{
struct connect_helper_s helper;
pthread_t tid;
int ret;
int waited_ms;
memcpy(&helper.addr, &g_ws_cached_addrs[i], sizeof(helper.addr));
helper.addr.sin_port = htons(port);
helper.result = -1;
helper.err = 0;
helper.done = false;
fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0)
{
printf("[TLS] socket failed: %d\n", errno);
continue;
}
helper.fd = fd;
pthread_attr_t attr;
pthread_attr_init(&attr);
pthread_attr_setstacksize(&attr, 2048);
ret = pthread_create(&tid, &attr, connect_thread_func, &helper);
pthread_attr_destroy(&attr);
if (ret != 0)
{
printf("[TLS] pthread_create failed: %d\n", ret);
close(fd);
continue;
}
waited_ms = 0;
while (!helper.done && waited_ms < timeout_ms)
{
usleep(50000);
waited_ms += 50;
}
if (!helper.done)
{
printf("[TLS] connect timeout on IP #%d after %dms\n",
i, waited_ms);
fflush(stdout);
close(fd);
int extra = 0;
while (!helper.done && extra < 2000)
{
usleep(50000);
extra += 50;
}
if (helper.done)
{
pthread_join(tid, NULL);
}
else
{
printf("[TLS] connect thread still stuck, detaching\n");
fflush(stdout);
pthread_detach(tid);
}
continue;
}
pthread_join(tid, NULL);
if (helper.result != 0)
{
printf("[TLS] connect failed on IP #%d: errno=%d\n",
i, helper.err);
close(fd);
continue;
}
goto connected;
}
printf("[TLS] All %d IPs failed\n", g_ws_cached_count);
return -EHOSTUNREACH;
connected:
{
struct timeval tv;
tv.tv_sec = timeout_ms / 1000;
tv.tv_usec = (timeout_ms % 1000) * 1000;
setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv));
}
return fd;
}
/* Wait with poll(), not select(): FDCHECK encodes fd values beyond
 * FD_SETSIZE, so FD_SET() would overflow its fd_set buffer.  poll() has no
 * such limitation.
 */
* Replaces mbedtls_net_send to ensure write() actually delivers data
* and to diagnose potential FDCHECK issues on the send path.
*/
static int my_net_send(void *ctx, const unsigned char *buf, size_t len)
{
mbedtls_net_context *net_ctx = (mbedtls_net_context *)ctx;
struct pollfd pfd;
int ret;
pfd.fd = net_ctx->fd;
pfd.events = POLLOUT;
pfd.revents = 0;
ret = poll(&pfd, 1, 10000);
if (ret == 0)
{
}
if (ret < 0)
{
printf("[TLS-SEND] poll error=%d, fd=%d\n", errno, net_ctx->fd);
return MBEDTLS_ERR_NET_SEND_FAILED;
}
if (pfd.revents & (POLLERR | POLLHUP))
{
printf("[TLS-SEND] poll POLLERR/HUP revents=0x%x, fd=%d\n",
pfd.revents, net_ctx->fd);
return MBEDTLS_ERR_NET_CONN_RESET;
}
ret = (int)write(net_ctx->fd, buf, len);
if (ret < 0)
{
printf("[TLS-SEND] write errno=%d\n", errno);
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
{
return MBEDTLS_ERR_SSL_WANT_WRITE;
}
if (errno == EPIPE || errno == ECONNRESET)
{
return MBEDTLS_ERR_NET_CONN_RESET;
}
return MBEDTLS_ERR_NET_SEND_FAILED;
}
return ret;
}
static int my_net_recv(void *ctx, unsigned char *buf, size_t len)
{
mbedtls_net_context *net_ctx = (mbedtls_net_context *)ctx;
struct pollfd pfd;
int ret;
pfd.fd = net_ctx->fd;
pfd.events = POLLIN;
pfd.revents = 0;
ret = poll(&pfd, 1, 10000);
if (ret == 0)
{
}
if (ret < 0)
{
printf("[TLS-RECV] poll error=%d, fd=%d\n", errno, net_ctx->fd);
return MBEDTLS_ERR_NET_RECV_FAILED;
}
if (pfd.revents & (POLLERR | POLLHUP))
{
printf("[TLS-RECV] poll POLLERR/HUP revents=0x%x, fd=%d\n",
pfd.revents, net_ctx->fd);
return MBEDTLS_ERR_NET_CONN_RESET;
}
ret = (int)read(net_ctx->fd, buf, len);
if (ret < 0)
{
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
{
return MBEDTLS_ERR_SSL_WANT_READ;
}
if (errno == EPIPE || errno == ECONNRESET)
{
return MBEDTLS_ERR_NET_CONN_RESET;
}
return MBEDTLS_ERR_NET_RECV_FAILED;
}
if (ret == 0)
{
printf("[TLS-RECV] read returned 0 (connection closed by peer)\n");
}
return ret;
}
/* Maximum consecutive WANT_READ/WANT_WRITE results tolerated during the
 * handshake before giving up.
 */
#define WS_TLS_HANDSHAKE_MAX_RETRIES 10

/* Open a TCP connection with tls_tcp_connect_ipv4() and complete a TLS
 * client handshake with host (also used for SNI and hostname checking).
 *
 * Certificates are checked against the bundled GlobalSign Root R3 in
 * MBEDTLS_SSL_VERIFY_OPTIONAL mode: the verification result is logged, but
 * a failed verification does not abort the connection.
 *
 * On success tls->inited is set and tls_close() releases everything.  On
 * failure every mbedTLS context and the socket are released here, and
 * tls->inited is left untouched (false for a calloc'd client).
 *
 * Returns 0, -EHOSTUNREACH if no TCP connection could be made, or -EIO for
 * any mbedTLS setup or handshake failure.
 */
static int tls_connect(struct ws_tls_s *tls, const char *host,
                       uint16_t port, int timeout_ms)
{
  char vrfy_buf[256];
  uint32_t flags;
  int retries = 0;
  int err = -EIO;
  int ret;
  int fd;

  mbedtls_net_init(&tls->net);
  mbedtls_ssl_init(&tls->ssl);
  mbedtls_ssl_config_init(&tls->conf);
  mbedtls_ctr_drbg_init(&tls->ctr_drbg);
  mbedtls_entropy_init(&tls->entropy);
  mbedtls_x509_crt_init(&tls->cacert);

  ret = mbedtls_x509_crt_parse(&tls->cacert,
                               (const unsigned char *)globalsign_root_r3_pem,
                               sizeof(globalsign_root_r3_pem));
  if (ret != 0)
    {
      printf("[TLS] CA cert parse failed: -0x%x\n", (unsigned int)-ret);
      goto fail;
    }

  ret = mbedtls_ctr_drbg_seed(&tls->ctr_drbg, mbedtls_entropy_func,
                              &tls->entropy,
                              (const unsigned char *)"bailian",
                              strlen("bailian"));
  if (ret != 0)
    {
      goto fail;
    }

  fd = tls_tcp_connect_ipv4(host, port, timeout_ms);
  if (fd < 0)
    {
      printf("[TLS] TCP connect failed: %d\n", fd);
      err = -EHOSTUNREACH;
      goto fail;
    }

  /* From here on mbedtls_net_free() (in the fail path or tls_close())
   * closes the socket.
   */
  tls->net.fd = fd;

  ret = mbedtls_ssl_config_defaults(&tls->conf, MBEDTLS_SSL_IS_CLIENT,
                                    MBEDTLS_SSL_TRANSPORT_STREAM,
                                    MBEDTLS_SSL_PRESET_DEFAULT);
  if (ret != 0)
    {
      goto fail;
    }

#ifdef CONFIG_MBEDTLS_DEBUG_C
  mbedtls_ssl_conf_dbg(&tls->conf, my_debug, NULL);
  mbedtls_debug_set_threshold(4);
#endif

  /* Verify the server chain against the bundled root, but in OPTIONAL mode
   * so we can see the verification result and still connect if it fails.
   * NOTE(review): OPTIONAL accepts an unverifiable (e.g. intercepted)
   * certificate; consider MBEDTLS_SSL_VERIFY_REQUIRED.
   */
  mbedtls_ssl_conf_ca_chain(&tls->conf, &tls->cacert, NULL);
  mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
  mbedtls_ssl_conf_rng(&tls->conf, mbedtls_ctr_drbg_random, &tls->ctr_drbg);

  /* NOTE(review): mbedTLS only applies this read timeout through a
   * recv_timeout BIO callback; mbedtls_ssl_set_bio() below passes NULL for
   * it, so this setting currently has no effect.
   */
  if (timeout_ms > 0)
    {
      mbedtls_ssl_conf_read_timeout(&tls->conf, timeout_ms);
    }

  ret = mbedtls_ssl_setup(&tls->ssl, &tls->conf);
  if (ret != 0)
    {
      goto fail;
    }

  ret = mbedtls_ssl_set_hostname(&tls->ssl, host);
  if (ret != 0)
    {
      goto fail;
    }

  mbedtls_ssl_set_bio(&tls->ssl, &tls->net, my_net_send,
                      my_net_recv, NULL);

  /* Retry WANT_READ/WANT_WRITE, but only a bounded number of times in a
   * row so an unresponsive server cannot hang the caller forever.  With
   * this file's BIO callbacks, each retry normally follows a wait of up to
   * 10 s inside them.
   */
  for (;;)
    {
      ret = mbedtls_ssl_handshake(&tls->ssl);
      if (ret != MBEDTLS_ERR_SSL_WANT_READ &&
          ret != MBEDTLS_ERR_SSL_WANT_WRITE)
        {
          break;
        }

      if (++retries > WS_TLS_HANDSHAKE_MAX_RETRIES)
        {
          printf("[TLS] Handshake gave up after %d retries\n",
                 WS_TLS_HANDSHAKE_MAX_RETRIES);
          break;
        }
    }

  flags = mbedtls_ssl_get_verify_result(&tls->ssl);
  if (flags != 0)
    {
      mbedtls_x509_crt_verify_info(vrfy_buf, sizeof(vrfy_buf),
                                   " ! ", flags);
    }

  if (ret != 0)
    {
      printf("[TLS] Handshake failed: -0x%x\n", (unsigned int)-ret);
      if (flags != 0)
        {
          printf("[TLS] Verify result: %s\n", vrfy_buf);
        }

      fflush(stdout);
      goto fail;
    }

  if (flags != 0)
    {
      printf("[TLS] WARNING verify: %s\n", vrfy_buf);
    }
  else
    {
      printf("[TLS] Server certificate verified OK\n");
    }

  fflush(stdout);

  tls->inited = true;
  return 0;

fail:
  /* Release everything acquired so far; the caller only free()s the client
   * on failure.  mbedtls_net_free() is a no-op while net.fd is still -1.
   */
  mbedtls_ssl_free(&tls->ssl);
  mbedtls_ssl_config_free(&tls->conf);
  mbedtls_ctr_drbg_free(&tls->ctr_drbg);
  mbedtls_entropy_free(&tls->entropy);
  mbedtls_x509_crt_free(&tls->cacert);
  mbedtls_net_free(&tls->net);
  return err;
}
/* Tear down a session opened by tls_connect(): send close_notify (result
 * ignored), release the network context (which closes the socket) and free
 * every mbedTLS context.  Does nothing for NULL or when tls->inited is
 * false.
 */
static void tls_close(struct ws_tls_s *tls)
{
  if (tls != NULL && tls->inited)
    {
      mbedtls_ssl_close_notify(&tls->ssl);
      mbedtls_net_free(&tls->net);
      mbedtls_ssl_free(&tls->ssl);
      mbedtls_ssl_config_free(&tls->conf);
      mbedtls_ctr_drbg_free(&tls->ctr_drbg);
      mbedtls_entropy_free(&tls->entropy);
      mbedtls_x509_crt_free(&tls->cacert);
      tls->inited = false;
    }
}
/* Write all len bytes through the TLS session, continuing after partial
 * writes.  WANT_READ/WANT_WRITE are retried immediately with the same
 * remaining buffer, as mbedTLS requires.  Returns len, or -EIO on any other
 * mbedTLS error.
 */
static ssize_t tls_write(struct ws_tls_s *tls, const uint8_t *data, size_t len)
{
  const uint8_t *cursor = data;
  size_t remaining = len;

  while (remaining > 0)
    {
      int n = mbedtls_ssl_write(&tls->ssl, cursor, remaining);

      if (n == MBEDTLS_ERR_SSL_WANT_READ || n == MBEDTLS_ERR_SSL_WANT_WRITE)
        {
          continue;
        }

      if (n < 0)
        {
          return -EIO;
        }

      cursor += n;
      remaining -= (size_t)n;
    }

  return (ssize_t)len;
}
static ssize_t tls_read(struct ws_tls_s *tls, uint8_t *data, size_t len)
{
int ret = mbedtls_ssl_read(&tls->ssl, data, len);
if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE)
{
return 0;
}
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
{
return 0;
}
if (ret < 0)
{
return -EIO;
}
return ret;
}
#endif
/* Send all len bytes on the client connection (TLS or plain TCP).
 *
 * The plain-TCP path loops over send(): a blocking send() can still return
 * a short count (e.g. when a signal arrives or SO_SNDTIMEO expires after
 * part of the data was queued), and a partially sent WebSocket frame would
 * corrupt the stream.  Interrupted sends are retried.
 *
 * Returns len on success or a negative errno (-ENOTSUP for wss:// when TLS
 * support is not compiled in).
 */
static ssize_t ws_write(struct bailian_ws_client *client, const uint8_t *data,
                        size_t len)
{
  size_t sent = 0;

  if (client->tls)
    {
#ifdef CONFIG_CRYPTO_MBEDTLS
      return tls_write(&client->tls_ctx, data, len);
#else
      return -ENOTSUP;
#endif
    }

  while (sent < len)
    {
      ssize_t n = send(client->fd, data + sent, len - sent, 0);

      if (n < 0)
        {
          int err = errno;

          if (err == EINTR)
            {
              continue;
            }

          return -err;
        }

      if (n == 0)
        {
          /* Defensive: avoid spinning if the stack accepts nothing. */
          return -EIO;
        }

      sent += (size_t)n;
    }

  return (ssize_t)sent;
}
/* Perform at most one read from the connection.
 *
 * TLS: delegates to tls_read(); waiting happens inside the BIO callbacks
 * and timeout_ms is not used.  Plain TCP: when timeout_ms > 0, first waits
 * up to timeout_ms for data with poll(); otherwise recv() is called
 * directly and may block, bounded only by any SO_RCVTIMEO on the socket.
 *
 * Returns bytes read (> 0), 0 if nothing arrived in time (poll timeout,
 * SO_RCVTIMEO expiry or EINTR), or a negative errno.  On plain TCP an
 * orderly shutdown by the peer is reported as -ECONNRESET, because 0 means
 * "no data yet" to the callers.
 */
static ssize_t ws_read_once(struct bailian_ws_client *client, uint8_t *data,
                            size_t len, int timeout_ms)
{
  ssize_t n;

  if (client->tls)
    {
#ifdef CONFIG_CRYPTO_MBEDTLS
      return tls_read(&client->tls_ctx, data, len);
#else
      return -ENOTSUP;
#endif
    }

  if (timeout_ms > 0)
    {
      struct pollfd pfd;
      int ready;

      pfd.fd = client->fd;
      pfd.events = POLLIN;
      pfd.revents = 0;

      ready = poll(&pfd, 1, timeout_ms);
      if (ready == 0)
        {
          return 0;
        }

      if (ready < 0)
        {
          int err = errno;
          return (err == EINTR) ? 0 : -err;
        }
    }

  n = recv(client->fd, data, len, 0);
  if (n < 0)
    {
      int err = errno;

      if (err == EAGAIN || err == EWOULDBLOCK || err == EINTR)
        {
          return 0;
        }

      return -err;
    }

  if (n == 0 && len > 0)
    {
      return -ECONNRESET;
    }

  return n;
}
/* Read exactly len bytes, looping because a single read may return fewer
 * bytes than requested (e.g. mbedtls_ssl_read: request 8000, get 4096 from
 * one TLS record).
 *
 * Returns len on success, the negative value from ws_read_once() on error,
 * or 0 if the very first read produced nothing.  Once at least one byte has
 * arrived, empty reads are retried after a 1 ms pause until the remaining
 * bytes come in or an error is reported.
 */
static ssize_t ws_read_exact(struct bailian_ws_client *client, uint8_t *data,
                             size_t len, int timeout_ms)
{
  uint8_t *cursor = data;
  size_t remaining = len;

  while (remaining > 0)
    {
      ssize_t got = ws_read_once(client, cursor, remaining, timeout_ms);

      if (got < 0)
        {
          return got;
        }

      if (got > 0)
        {
          cursor += got;
          remaining -= (size_t)got;
          continue;
        }

      /* Nothing this round: only give up if the read has not started. */
      if (remaining == len)
        {
          return 0;
        }

      usleep(1000);
    }

  return (ssize_t)len;
}
/* Compute the Sec-WebSocket-Accept value expected for key:
 * base64(SHA-1(key || WS_GUID)), as defined by RFC 6455.
 *
 * accept receives the NUL-terminated result: 28 characters, so accept_len
 * must be at least 29 for the full value.  Returns 0, or -EINVAL if an
 * argument is NULL or key is too long for the internal buffer.  Without
 * that check, the truncated input would silently yield a wrong value.
 */
static int ws_build_accept(const char *key, char *accept, size_t accept_len)
{
  char concat[128];
  uint8_t digest[20];
  sha1_ctx_t ctx;
  int n;

  if (key == NULL || accept == NULL)
    {
      return -EINVAL;
    }

  n = snprintf(concat, sizeof(concat), "%s%s", key, WS_GUID);
  if (n < 0 || (size_t)n >= sizeof(concat))
    {
      return -EINVAL;
    }

  sha1_init(&ctx);
  sha1_update(&ctx, (const uint8_t *)concat, (size_t)n);
  sha1_final(&ctx, digest);
  base64_encode(digest, sizeof(digest), accept, accept_len);
  return 0;
}
void bailian_ws_close(struct bailian_ws_client *client);

/* Consecutive empty reads tolerated while reading the upgrade response. */
#define WS_HANDSHAKE_MAX_EMPTY_READS 3

/* Read and validate the server's HTTP upgrade response.
 *
 * Bytes are consumed one at a time and only up to the blank line that ends
 * the headers, so any WebSocket frame sent right after the response is left
 * for bailian_ws_recv().  Headers of any total size are accepted.  Each line
 * is inspected through a 256-byte buffer; longer lines are truncated.
 *
 * The response must:
 *   - start with an HTTP status line containing " 101 ", and
 *   - carry a Sec-WebSocket-Accept header containing the expected value.
 *
 * Up to WS_HANDSHAKE_MAX_EMPTY_READS consecutive empty reads are tolerated.
 * These are timeouts, or TLS records without application data.
 *
 * Returns 0, -EIO on read failure/timeout, or -EPROTO for a bad response.
 */
static int ws_check_upgrade_response(struct bailian_ws_client *client,
                                     const char *accept, int timeout_ms)
{
  char line[256];
  size_t used = 0;
  bool first_line = true;
  bool status_ok = false;
  bool accept_ok = false;
  int empty_reads = 0;

  for (;;)
    {
      uint8_t ch;
      ssize_t ret = ws_read_once(client, &ch, 1, timeout_ms);

      if (ret < 0)
        {
          return -EIO;
        }

      if (ret == 0)
        {
          if (++empty_reads > WS_HANDSHAKE_MAX_EMPTY_READS)
            {
              return -EIO;
            }

          continue;
        }

      empty_reads = 0;

      if (ch != '\n')
        {
          if (ch != '\r' && used + 1 < sizeof(line))
            {
              line[used++] = (char)ch;
            }

          continue;
        }

      line[used] = '\0';
      if (used == 0)
        {
          break;  /* Blank line: end of headers */
        }

      if (first_line)
        {
          status_ok = strncmp(line, "HTTP/", 5) == 0 &&
                      strstr(line, " 101 ") != NULL;
          first_line = false;
        }
      else if (strcasestr(line, "Sec-WebSocket-Accept:") == line &&
               strstr(line, accept) != NULL)
        {
          accept_ok = true;
        }

      used = 0;
    }

  return (status_ok && accept_ok) ? 0 : -EPROTO;
}

/* Connect to a ws:// or wss:// URL and perform the WebSocket opening
 * handshake.
 *
 * extra_headers (may be NULL) is inserted verbatim before the terminating
 * blank line.  It must therefore consist of complete header lines, each
 * ending in "\r\n".  timeout_ms is passed to the TCP/TLS connect and used
 * for reading the response.
 *
 * On success *out receives a client to be released with bailian_ws_close().
 *
 * Returns 0 on success, or a negative errno:
 *   -EINVAL    bad URL or arguments
 *   -ENOMEM    allocation failure
 *   -ENOTSUP   wss:// without TLS support
 *   -E2BIG     request does not fit the request buffer
 *   -EIO       no or failed response
 *   -EPROTO    handshake rejected
 *   otherwise  the connect or write error
 */
int bailian_ws_connect(const char *url, const char *extra_headers,
                       int timeout_ms, struct bailian_ws_client **out)
{
  struct ws_url_s parsed;
  struct bailian_ws_client *client;
  uint8_t key_raw[16];
  char key_b64[64];
  char accept[64];
  char request[512];
  ssize_t ret;
  int req_len;

  if (parse_ws_url(url, &parsed) != 0 || out == NULL)
    {
      printf("[WS] URL parse failed\n");
      return -EINVAL;
    }

  client = calloc(1, sizeof(*client));
  if (client == NULL)
    {
      printf("[WS] calloc failed\n");
      return -ENOMEM;
    }

  /* Only the plain-TCP path owns a raw fd; keep it invalid otherwise. */
  client->fd = -1;
  client->tls = parsed.tls;

  if (parsed.tls)
    {
#ifdef CONFIG_CRYPTO_MBEDTLS
      ret = tls_connect(&client->tls_ctx, parsed.host, parsed.port,
                        timeout_ms);
      if (ret < 0)
        {
          free(client);
          return (int)ret;
        }
#else
      free(client);
      return -ENOTSUP;
#endif
    }
  else
    {
      int fd = tcp_connect(parsed.host, parsed.port, timeout_ms);

      if (fd < 0)
        {
          /* Return the local copy: client must not be read after free(). */
          free(client);
          return fd;
        }

      client->fd = fd;
    }

  /* NOTE(review): rand() is never seeded here, so keys repeat across boots. */
  for (size_t i = 0; i < sizeof(key_raw); i++)
    {
      key_raw[i] = (uint8_t)(rand() & 0xff);
    }

  base64_encode(key_raw, sizeof(key_raw), key_b64, sizeof(key_b64));

  if (ws_build_accept(key_b64, accept, sizeof(accept)) != 0)
    {
      bailian_ws_close(client);
      return -EINVAL;
    }

  req_len = snprintf(request, sizeof(request),
                     "GET %s HTTP/1.1\r\n"
                     "Host: %s\r\n"
                     "Upgrade: websocket\r\n"
                     "Connection: Upgrade\r\n"
                     "Sec-WebSocket-Key: %s\r\n"
                     "Sec-WebSocket-Version: 13\r\n"
                     "%s\r\n",
                     parsed.path, parsed.host, key_b64,
                     extra_headers != NULL ? extra_headers : "");
  if (req_len < 0 || (size_t)req_len >= sizeof(request))
    {
      /* A truncated request would lose the terminating blank line. */
      printf("[WS] handshake request too long\n");
      bailian_ws_close(client);
      return -E2BIG;
    }

  ret = ws_write(client, (const uint8_t *)request, (size_t)req_len);
  if (ret < 0)
    {
      bailian_ws_close(client);
      return (int)ret;
    }

  ret = ws_check_upgrade_response(client, accept, timeout_ms);
  if (ret < 0)
    {
      bailian_ws_close(client);
      return (int)ret;
    }

  *out = client;
  return 0;
}
/* Send one complete (FIN = 1) client frame with the given opcode.
 *
 * Client frames are masked with a fresh 4-byte key from rand() (RFC 6455
 * requires client-to-server masking).  Header and masked payload are built
 * in a single heap buffer and written with one ws_write().  If allocation
 * fails, nothing has been sent, so the stream is never left holding a frame
 * header without its payload.
 *
 * Returns 0 on success, -EINVAL for a NULL client, a NULL data pointer with
 * len > 0, or an unrepresentable len, -ENOMEM, or -EIO if the write fails.
 */
int bailian_ws_send(struct bailian_ws_client *client, uint8_t opcode,
                    const uint8_t *data, size_t len)
{
  uint8_t header[14];
  uint8_t mask[4];
  uint8_t *frame;
  size_t header_len;
  size_t i;
  ssize_t ret;

  if (client == NULL || (data == NULL && len > 0))
    {
      return -EINVAL;
    }

  if (len > SIZE_MAX - sizeof(header))
    {
      return -EINVAL;
    }

  header[0] = 0x80 | (opcode & 0x0f);

  if (len < 126)
    {
      header[1] = 0x80 | (uint8_t)len;
      header_len = 2;
    }
  else if (len < 65536)
    {
      header[1] = 0x80 | 126;
      header[2] = (uint8_t)(len >> 8);
      header[3] = (uint8_t)(len & 0xff);
      header_len = 4;
    }
  else
    {
      /* Widen first: shifting a 32-bit size_t by 32+ bits is undefined. */
      uint64_t len64 = (uint64_t)len;

      header[1] = 0x80 | 127;
      for (i = 0; i < 8; i++)
        {
          header[2 + i] = (uint8_t)((len64 >> ((7 - i) * 8)) & 0xff);
        }

      header_len = 10;
    }

  for (i = 0; i < 4; i++)
    {
      mask[i] = (uint8_t)(rand() & 0xff);
    }

  memcpy(&header[header_len], mask, 4);
  header_len += 4;

  frame = malloc(header_len + len);
  if (frame == NULL)
    {
      return -ENOMEM;
    }

  memcpy(frame, header, header_len);
  for (i = 0; i < len; i++)
    {
      frame[header_len + i] = data[i] ^ mask[i % 4];
    }

  ret = ws_write(client, frame, header_len + len);
  free(frame);

  return (ret < 0) ? -EIO : 0;
}
/* Receive one WebSocket frame into data (capacity len bytes).
 *
 * The 2-byte frame header is read with ws_read_exact().  If the header
 * arrives split across reads, it is reassembled rather than left partly
 * uninitialised.  Payload lengths are decoded as 64-bit values, so a large
 * extended length cannot wrap a 32-bit size_t.  Masked payloads are
 * unmasked in place.  FIN is not inspected, so each fragment is returned
 * on its own with its own opcode.
 *
 * Returns:
 *   > 0  payload length; *opcode holds the frame opcode
 *   0    no frame started within timeout_ms, an empty payload, a ping
 *        (answered with a pong echoing the payload), or a frame larger than
 *        len that was read and discarded
 *   < 0  -EINVAL for bad arguments, -EIO for read/protocol errors, or the
 *        negative value reported while reading the first header bytes
 */
int bailian_ws_recv(struct bailian_ws_client *client, uint8_t *opcode,
                    uint8_t *data, size_t len, int timeout_ms)
{
  uint8_t header[2];
  uint8_t ext_len[8];
  uint8_t mask[4];
  uint64_t payload_len;
  size_t n;
  size_t i;
  bool masked;
  ssize_t ret;

  if (client == NULL || opcode == NULL || data == NULL)
    {
      return -EINVAL;
    }

  ret = ws_read_exact(client, header, sizeof(header), timeout_ms);
  if (ret <= 0)
    {
      return (int)ret;
    }

  *opcode = header[0] & 0x0f;
  masked = (header[1] & 0x80) != 0;
  payload_len = header[1] & 0x7f;

  if (payload_len == 126)
    {
      ret = ws_read_exact(client, ext_len, 2, timeout_ms);
      if (ret != 2)
        {
          return -EIO;
        }

      payload_len = ((uint64_t)ext_len[0] << 8) | ext_len[1];
    }
  else if (payload_len == 127)
    {
      ret = ws_read_exact(client, ext_len, 8, timeout_ms);
      if (ret != 8)
        {
          return -EIO;
        }

      payload_len = 0;
      for (i = 0; i < 8; i++)
        {
          payload_len = (payload_len << 8) | ext_len[i];
        }

      /* RFC 6455: the most significant bit of a 64-bit length must be 0. */
      if ((payload_len >> 63) != 0)
        {
          return -EIO;
        }
    }

  if (payload_len > len)
    {
      uint64_t skip = payload_len + (masked ? 4 : 0);
      uint8_t drain[512];

      printf("[WS] frame too large: %llu > %zu, draining\n",
             (unsigned long long)payload_len, len);

      while (skip > 0)
        {
          size_t chunk = (skip < sizeof(drain)) ? (size_t)skip
                                                : sizeof(drain);

          ret = ws_read_once(client, drain, chunk, timeout_ms);
          if (ret <= 0)
            {
              return -EIO;
            }

          skip -= (uint64_t)ret;
        }

      return 0;
    }

  n = (size_t)payload_len;

  if (masked)
    {
      ret = ws_read_exact(client, mask, sizeof(mask), timeout_ms);
      if (ret != (ssize_t)sizeof(mask))
        {
          return -EIO;
        }
    }

  ret = ws_read_exact(client, data, n, timeout_ms);
  if (ret != (ssize_t)n)
    {
      return -EIO;
    }

  if (masked)
    {
      for (i = 0; i < n; i++)
        {
          data[i] ^= mask[i % 4];
        }
    }

  if (*opcode == 0x9)
    {
      /* Ping: answer with a pong echoing the payload (send errors ignored). */
      bailian_ws_send(client, 0xA, data, n);
      return 0;
    }

  return (int)n;
}
/* Close the connection and free client; NULL is ignored.  Plain-TCP clients
 * have their socket closed when it is valid.  TLS clients are shut down via
 * tls_close() when mbedTLS support is built in.
 */
void bailian_ws_close(struct bailian_ws_client *client)
{
  if (client == NULL)
    {
      return;
    }

  if (!client->tls)
    {
      if (client->fd >= 0)
        {
          close(client->fd);
        }
    }
#ifdef CONFIG_CRYPTO_MBEDTLS
  else
    {
      tls_close(&client->tls_ctx);
    }
#endif

  free(client);
}