/*
* Copyright (c) 2023 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <atomic>
#include <thread>
#include "CommandLineInterface.h"
#include "PreviewerEngineLog.h"
#include "WebSocketServer.h"
// Handle of the most recently established client connection; assigned in ProtocolCallback.
lws* WebSocketServer::webSocket{nullptr};
// Stop flag for the service loop; raised by SignalHandler or when lws_service reports a failure.
std::atomic<bool> WebSocketServer::interrupted{false};
// Write-readiness state, updated by ProtocolCallback and polled by WriteData.
WebSocketServer::WebSocketState WebSocketServer::webSocketWritable{WebSocketState::INIT};
// Image buffer re-sent by ProtocolCallback after a reconnect; the payload is read from offset LWS_PRE.
uint8_t* WebSocketServer::firstImageBuffer{nullptr};
// Payload size in bytes for firstImageBuffer; zero means there is nothing to re-send.
uint64_t WebSocketServer::firstImagebufferSize{0};
WebSocketServer::WebSocketServer() : serverThread(nullptr), serverPort(0)
{
protocols[0] = {"ws", WebSocketServer::ProtocolCallback, 0, MAX_PAYLOAD_SIZE};
protocols[1] = {NULL, NULL, 0, 0};
}
// Nothing to release explicitly; the members clean up after themselves.
WebSocketServer::~WebSocketServer() = default;
// Returns the process-wide server object, constructed on first use.
WebSocketServer& WebSocketServer::GetInstance()
{
    static WebSocketServer instance;
    return instance;
}
void WebSocketServer::SetServerPort(int port)
{
serverPort = port;
}
void WebSocketServer::SetSid(const std::string curSid)
{
sid = curSid;
}
// Validates the request URI of an incoming connection against the configured sid.
//
// Returns true when no sid is configured, or when the text after the last '/'
// of the GET URI equals the configured sid. Returns false when the URI header
// cannot be copied (missing, or longer than sidMaxLength allows), contains no
// '/', or its last segment differs from the sid.
bool WebSocketServer::CheckSid(struct lws* wsi)
{
    const auto& expectedSid = WebSocketServer::GetInstance().sid;
    if (expectedSid.empty()) {
        return true;
    }

    // Scratch buffer for the URI header; lws_hdr_copy reports the copied length.
    std::string requestPath(WebSocketServer::sidMaxLength, '\0');
    const int copiedLength = lws_hdr_copy(wsi, &requestPath[0], requestPath.size(), WSI_TOKEN_GET_URI);
    if (copiedLength <= 0) {
        return false;
    }
    // Trim the buffer to the bytes actually copied (always at least one here).
    requestPath.resize(copiedLength);

    const std::string::size_type slashPos = requestPath.rfind('/');
    if (slashPos == std::string::npos) {
        return false;
    }

    // Compare everything after the last '/' with the expected sid.
    const bool sidMatches = requestPath.compare(slashPos + 1, std::string::npos, expectedSid) == 0;
    if (!sidMatches) {
        ELOG("Websocket path validation failed.");
        return false;
    }
    ILOG("Websocket path validation passed.");
    return true;
}
int WebSocketServer::ProtocolCallback(struct lws* wsi, enum lws_callback_reasons reason,
void* user, void* in, size_t len)
{
switch (reason) {
case LWS_CALLBACK_FILTER_PROTOCOL_CONNECTION:
if (!CheckSid(wsi)) {
return 1;
}
break;
case LWS_CALLBACK_PROTOCOL_INIT:
ILOG("Engine Websocket protocol init");
break;
case LWS_CALLBACK_ESTABLISHED:
ILOG("Websocket client connect");
webSocket = wsi;
lws_callback_on_writable(wsi);
break;
case LWS_CALLBACK_RECEIVE:
break;
case LWS_CALLBACK_SERVER_WRITEABLE:
ILOG("Engine websocket server writeable");
if (firstImagebufferSize > 0 && webSocketWritable == WebSocketState::UNWRITEABLE &&
firstImageBuffer != nullptr) {
ILOG("Send last image after websocket reconnected");
std::lock_guard<std::mutex> guard(WebSocketServer::GetInstance().mutex);
lws_write(wsi,
firstImageBuffer + LWS_PRE,
firstImagebufferSize,
LWS_WRITE_BINARY);
}
webSocketWritable = WebSocketState::WRITEABLE;
break;
case LWS_CALLBACK_CLOSED:
ILOG("Websocket client connection closed");
webSocketWritable = WebSocketState::UNWRITEABLE;
break;
default:
break;
}
return 0;
}
// Signal handler installed for SIGINT: raises the stop flag so the service
// loop in StartWebsocketListening exits. The signal number itself is not used.
void WebSocketServer::SignalHandler(int sig)
{
    interrupted.store(true);
}
void WebSocketServer::StartWebsocketListening()
{
const auto sig = signal(SIGINT, SignalHandler);
if (sig == SIG_ERR) {
ELOG("StartWebsocketListening failed");
return;
}
ILOG("Begin to start websocket listening!");
struct lws_context_creation_info contextInfo = {0};
contextInfo.port = serverPort;
contextInfo.iface = serverHostname;
contextInfo.protocols = protocols;
contextInfo.ip_limit_wsi = websocketMaxConn;
contextInfo.options = LWS_SERVER_OPTION_VALIDATE_UTF8;
struct lws_context* context = lws_create_context(&contextInfo);
if (context == nullptr) {
ELOG("WebSocketServer::StartWebsocketListening context memory allocation failed");
return;
}
while (!interrupted) {
if (lws_service(context, WEBSOCKET_SERVER_TIMEOUT)) {
interrupted = true;
}
}
lws_context_destroy(context);
}
void WebSocketServer::Run()
{
serverThread = std::make_unique<std::thread>([this]() {
this->StartWebsocketListening();
});
if (serverThread == nullptr) {
ELOG("WebSocketServer::Start serverThread memory allocation failed");
return;
}
serverThread->detach();
}
size_t WebSocketServer::WriteData(unsigned char* data, size_t length)
{
while (webSocketWritable != WebSocketState::WRITEABLE) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
if (webSocket != nullptr && webSocketWritable == WebSocketState::WRITEABLE) {
return lws_write(webSocket, data, length, LWS_WRITE_BINARY);
}
return 0;
}