#include "net/websockets/websocket_basic_stream.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <utility>
#include "base/big_endian.h"
#include "base/containers/span.h"
#include "base/time/time.h"
#include "net/base/io_buffer.h"
#include "net/base/privacy_mode.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/public/secure_dns_policy.h"
#include "net/log/test_net_log.h"
#include "net/socket/connect_job.h"
#include "net/socket/socket_tag.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/ssl_client_socket.h"
#include "net/test/gtest_util.h"
#include "net/test/test_with_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/scheme_host_port.h"
#include "url/url_constants.h"
using net::test::IsError;
using net::test::IsOk;
namespace net {
namespace {
// Defines a byte-string constant k<ident> initialised from |literal| and a
// companion k<ident>Size holding its length in bytes, not counting the
// literal's implicit trailing NUL. The size comes from the array extent rather
// than strlen(), so embedded NUL bytes (common in frame headers) are counted.
#define WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(ident, literal) \
  const char k##ident[] = literal;                                  \
  const size_t k##ident##Size = std::size(k##ident) - 1
// Wire-format WebSocket frames used as scripted socket input and expected
// output. Each invocation defines k<Name> (the bytes) and k<Name>Size (the byte
// count without the literal's trailing NUL).
// Layout per RFC 6455 section 5.2:
//   - First byte: 0x80 is FIN and the low nibble is the opcode.
//   - Second byte: 0x80 is the MASK bit and the low seven bits are the payload
//     length. Values 0x7E/0x7F select a 16-/64-bit extended length.

// Complete, unmasked, final text frame with the 6-byte payload "Sample".
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(SampleFrame, "\x81\x06Sample");
// Beginning of a text frame (FIN set in the header) whose 64-bit extended
// length, 0x7FFFFFFF, is far larger than the payload bytes that follow. Only a
// leading part of the frame is ever delivered. The literal is split so the
// final \xFF escape does not absorb the hex digit 'c' that starts the payload.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(
PartialLargeFrame,
"\x81\x7F\x00\x00\x00\x00\x7F\xFF\xFF\xFF"
"chromiunum ad pasco per loca insanis pullum manducat frumenti");
// Header size of kPartialLargeFrame: 2 base bytes + 8-byte extended length.
const size_t kLargeFrameHeaderSize = 10;
// Three complete final text frames with the one-byte payloads "X", "Y", "Z".
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(MultipleFrames,
"\x81\x01X\x81\x01Y\x81\x01Z");
// First fragment of a fragmented message: text opcode, no FIN, empty payload.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(EmptyFirstFrame, "\x01\x00");
// Middle fragment: continuation opcode, no FIN, empty payload.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(EmptyMiddleFrame, "\x00\x00");
// Final, unfragmented text frame with an empty payload.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(EmptyFinalTextFrame, "\x81\x00");
// Last fragment of a fragmented message: continuation opcode, FIN, empty.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(EmptyFinalContinuationFrame,
"\x80\x00");
// Final pong control frame with an empty payload.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(ValidPong, "\x8A\x00");
// Text frame whose 7-byte length is carried in the 16-bit extended field.
// NOTE(review): presumably rejected because RFC 6455 requires the minimal
// length encoding (7 fits in the 7-bit field).
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(InvalidFrame,
"\x81\x7E\x00\x07Invalid");
// Ping control frame without FIN. Control frames may not be fragmented.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(PingFrameWithoutFin, "\x09\x00");
// Pong whose 16-bit extended length is 126 (0x007E), one more than the
// 125-byte maximum for control-frame payloads.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(
126BytePong,
"\x8a\x7e\x00\x7eZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ"
"ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ");
// Final close frame with a 9-byte payload: status code 0x03E8 (1000) followed
// by the reason "occludo".
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(CloseFrame,
"\x88\x09\x03\xe8occludo");
// The bytes the write tests expect on the wire: final text frame, MASK bit set
// with length 5, an all-zero masking key, then "Write" (unchanged by a zero
// mask).
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(WriteFrame,
"\x81\x85\x00\x00\x00\x00Write");
// Final pong with MASK set, zero payload length and an all-zero masking key.
WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(MaskedEmptyPong,
"\x8A\x80\x00\x00\x00\x00");
// Masking keys supplied to the stream through the fixture's |generator_|.
// XOR with the all-zero key leaves payload bytes unchanged, which keeps the
// expected wire bytes of the write tests readable. The non-zero key lets a
// test observe real masking.
const WebSocketMaskingKey kNulMaskingKey = {{'\0', '\0', '\0', '\0'}};
const WebSocketMaskingKey kNonNulMaskingKey = {
    {'\x0d', '\x1b', '\x06', '\x17'}};

// Generator that always yields kNulMaskingKey.
WebSocketMaskingKey GenerateNulMaskingKey() {
  return kNulMaskingKey;
}

// Generator that always yields kNonNulMaskingKey.
WebSocketMaskingKey GenerateNonNulMaskingKey() {
  return kNonNulMaskingKey;
}
// StaticSocketDataProvider that, when constructed with |strict_mode| set,
// verifies on destruction that every scripted read and write was consumed.
class StrictStaticSocketDataProvider : public StaticSocketDataProvider {
 public:
  StrictStaticSocketDataProvider(base::span<const MockRead> reads,
                                 base::span<const MockWrite> writes,
                                 bool strict_mode)
      : StaticSocketDataProvider(reads, writes), strict_mode_(strict_mode) {}

  ~StrictStaticSocketDataProvider() override {
    // Non-strict tests deliberately leave some scripted I/O unused.
    if (!strict_mode_)
      return;
    EXPECT_EQ(read_count(), read_index());
    EXPECT_EQ(write_count(), write_index());
  }

 private:
  // Whether the destructor checks that all scripted I/O was consumed.
  const bool strict_mode_;
};
// Base fixture. Owns a mock socket factory and a mock transport pool, and
// builds the WebSocketBasicStream under test (|stream_|) on top of a mock
// socket whose reads and writes each test scripts.
class WebSocketBasicStreamSocketTest : public TestWithTaskEnvironment {
 protected:
  // Only the client socket factory is passed to CommonConnectJobParams. Its
  // thirteen remaining arguments are null.
  // NOTE(review): this presumably suffices because the mock pool takes its
  // sockets straight from |factory_|. Confirm if CommonConnectJobParams or
  // MockTransportClientSocketPool start relying on the other dependencies.
  // The pool is built with (1, 1, ...), presumably a one-socket limit overall
  // and per group. Confirm against MockTransportClientSocketPool.
  WebSocketBasicStreamSocketTest()
      : common_connect_job_params_(&factory_,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr,
                                   nullptr),
        pool_(1, 1, &common_connect_job_params_),
        generator_(&GenerateNulMaskingKey) {}

  // |stream_| owns the ClientSocketHandle returned by MakeTransportSocket().
  // That handle's mock socket was fed from |socket_data_| and presumably still
  // refers to it, so the stream is torn down explicitly before any other
  // member.
  ~WebSocketBasicStreamSocketTest() override { stream_.reset(); }

  // Returns a connected ClientSocketHandle whose socket replays |reads| and
  // expects |writes|. The scripted data lives in |socket_data_|. When
  // |expect_all_io_to_complete_| is true, that provider also checks on
  // destruction that all of it was consumed.
  std::unique_ptr<ClientSocketHandle> MakeTransportSocket(
      base::span<const MockRead> reads,
      base::span<const MockWrite> writes) {
    socket_data_ = std::make_unique<StrictStaticSocketDataProvider>(
        reads, writes, expect_all_io_to_complete_);
    socket_data_->set_connect_data(MockConnect(SYNCHRONOUS, OK));
    factory_.AddSocketDataProvider(socket_data_.get());

    auto transport_socket = std::make_unique<ClientSocketHandle>();
    scoped_refptr<ClientSocketPool::SocketParams> null_params;
    ClientSocketPool::GroupId group_id(
        url::SchemeHostPort(url::kHttpScheme, "a", 80),
        PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
        SecureDnsPolicy::kAllow);
    // The connect is scripted as SYNCHRONOUS/OK above, so Init() is expected
    // to finish immediately. Checking its result reports a broken setup here
    // instead of as an obscure failure once the stream uses the socket.
    const int rv = transport_socket->Init(
        group_id, null_params, absl::nullopt, MEDIUM, SocketTag(),
        ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
        ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
    EXPECT_THAT(rv, IsOk());
    return transport_socket;
  }

  // Installs a copy of |size| bytes of |data| as already-read HTTP data.
  // Streams created afterwards consume it before reading from the socket.
  // set_offset(size) marks all |size| bytes as filled.
  void SetHttpReadBuffer(const char* data, size_t size) {
    http_read_buffer_ = base::MakeRefCounted<GrowableIOBuffer>();
    http_read_buffer_->SetCapacity(size);
    memcpy(http_read_buffer_->data(), data, size);
    http_read_buffer_->set_offset(size);
  }

  // Creates |stream_| over a fresh mock socket scripted with |reads| and
  // |writes|. It uses the fixture's current HTTP read buffer, sub-protocol,
  // extensions and masking-key generator.
  void CreateStream(base::span<const MockRead> reads,
                    base::span<const MockWrite> writes) {
    stream_ = WebSocketBasicStream::CreateWebSocketBasicStreamForTesting(
        MakeTransportSocket(reads, writes), http_read_buffer_, sub_protocol_,
        extensions_, net_log_, generator_);
  }

  // Scripted I/O for the current socket. Replaced by each
  // MakeTransportSocket() call.
  std::unique_ptr<SocketDataProvider> socket_data_;
  MockClientSocketFactory factory_;
  const CommonConnectJobParams common_connect_job_params_;
  MockTransportClientSocketPool pool_;
  // Frames read from, or to be written to, |stream_|.
  std::vector<std::unique_ptr<WebSocketFrame>> frames_;
  // Completion callback for asynchronous ReadFrames()/WriteFrames() calls.
  TestCompletionCallback cb_;
  // Pre-read HTTP data handed to the stream at creation; null unless
  // SetHttpReadBuffer() was called.
  scoped_refptr<GrowableIOBuffer> http_read_buffer_;
  // Values passed to the stream at creation.
  std::string sub_protocol_;
  std::string extensions_;
  NetLogWithSource net_log_;
  // Masking-key source for written frames. Defaults to the all-zero key.
  WebSocketBasicStream::WebSocketMaskingKeyGeneratorFunction generator_;
  // When true, the socket data provider checks that all scripted I/O ran.
  bool expect_all_io_to_complete_ = true;
  // Stream under test.
  std::unique_ptr<WebSocketBasicStream> stream_;
};
class WebSocketBasicStreamSocketSingleReadTest
: public WebSocketBasicStreamSocketTest {
protected:
void CreateRead(const MockRead& read) {
reads_[0] = read;
CreateStream(reads_, base::span<MockWrite>());
}
MockRead reads_[1];
};
class WebSocketBasicStreamSocketChunkedReadTest
: public WebSocketBasicStreamSocketTest {
protected:
enum LastFrameBehaviour {
LAST_FRAME_BIG,
LAST_FRAME_NOT_BIG
};
void CreateChunkedRead(IoMode mode,
const char data[],
size_t data_size,
int chunk_size,
size_t number_of_chunks,
LastFrameBehaviour last_frame_behaviour) {
reads_.clear();
const char* start = data;
for (size_t i = 0; i < number_of_chunks; ++i) {
int len = chunk_size;
const bool is_last_chunk = (i == number_of_chunks - 1);
if ((last_frame_behaviour == LAST_FRAME_BIG && is_last_chunk) ||
static_cast<int>(data + data_size - start) < len) {
len = static_cast<int>(data + data_size - start);
}
reads_.emplace_back(mode, start, len);
start += len;
}
CreateStream(reads_, base::span<MockWrite>());
}
std::vector<MockRead> reads_;
};
// Fixture for write tests. Every test starts with one frame queued in
// |frames_| whose serialisation is expected to equal kWriteFrame.
class WebSocketBasicStreamSocketWriteTest
    : public WebSocketBasicStreamSocketTest {
 protected:
  void SetUp() override { PrepareWriteFrame(); }

  // Appends to |frames_| a final, masked text frame. Its payload is the part
  // of kWriteFrame that follows the base header and the masking key.
  void PrepareWriteFrame() {
    const size_t header_bytes = WebSocketFrameHeader::kBaseHeaderSize +
                                WebSocketFrameHeader::kMaskingKeyLength;
    const size_t payload_size = kWriteFrameSize - header_bytes;

    auto payload = base::MakeRefCounted<IOBuffer>(payload_size);
    // The frame only points into this buffer; |frame_buffers_| keeps it alive.
    frame_buffers_.push_back(payload);
    memcpy(payload->data(), kWriteFrame + header_bytes, payload_size);

    auto frame =
        std::make_unique<WebSocketFrame>(WebSocketFrameHeader::kOpCodeText);
    frame->payload = payload->data();
    frame->header.final = true;
    frame->header.masked = true;
    frame->header.payload_length = payload_size;
    frames_.push_back(std::move(frame));
  }

  // Owns the payload storage referenced by the frames in |frames_|.
  std::vector<scoped_refptr<IOBuffer>> frame_buffers_;
};
// Fixture for the read-buffer-size switching logic (BufferSizeManager).
class WebSocketBasicStreamSwitchTest : public WebSocketBasicStreamSocketTest {
 protected:
  // Maps |microseconds| to a TimeTicks offset from one fixed origin (Unix
  // epoch + 60 s), so tests can write read timings as plain integers.
  base::TimeTicks MicrosecondsFromStart(int microseconds) {
    // Computed once so every call measures from the same origin.
    static const base::TimeTicks kOrigin = [] {
      return base::TimeTicks::UnixEpoch() + base::Seconds(60);
    }();
    const auto offset = base::Microseconds(microseconds);
    return kOrigin + offset;
  }

  // Object under test.
  WebSocketBasicStream::BufferSizeManager buffer_size_manager_;
};
// Smoke test: a stream can be created over a socket with no scripted reads or
// writes.
TEST_F(WebSocketBasicStreamSocketTest, ConstructionWorks) {
CreateStream(base::span<MockRead>(), base::span<MockWrite>());
}
// A whole frame delivered by one synchronous read is returned synchronously
// as a single final frame carrying its 6-byte payload.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, SyncReadWorks) {
CreateRead(MockRead(SYNCHRONOUS, kSampleFrame, kSampleFrameSize));
int result = stream_->ReadFrames(&frames_, cb_.callback());
EXPECT_THAT(result, IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(UINT64_C(6), frames_[0]->header.payload_length);
EXPECT_TRUE(frames_[0]->header.final);
}
// Same as SyncReadWorks, but the read completes asynchronously. ReadFrames()
// returns ERR_IO_PENDING and the frame is available once the callback runs.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, AsyncReadWorks) {
CreateRead(MockRead(ASYNC, kSampleFrame, kSampleFrameSize));
int result = stream_->ReadFrames(&frames_, cb_.callback());
ASSERT_THAT(result, IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(UINT64_C(6), frames_[0]->header.payload_length);
}
// The header is split across two synchronous reads: the first byte, then the
// remaining 7 bytes (LAST_FRAME_BIG). The stream reassembles the header and
// returns the whole frame synchronously.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, HeaderFragmentedSync) {
CreateChunkedRead(
SYNCHRONOUS, kSampleFrame, kSampleFrameSize, 1, 2, LAST_FRAME_BIG);
int result = stream_->ReadFrames(&frames_, cb_.callback());
EXPECT_THAT(result, IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(UINT64_C(6), frames_[0]->header.payload_length);
}
// Asynchronous version of HeaderFragmentedSync. Both reads complete
// asynchronously and the reassembled frame arrives via the callback.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, HeaderFragmentedAsync) {
CreateChunkedRead(
ASYNC, kSampleFrame, kSampleFrameSize, 1, 2, LAST_FRAME_BIG);
int result = stream_->ReadFrames(&frames_, cb_.callback());
ASSERT_THAT(result, IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(UINT64_C(6), frames_[0]->header.payload_length);
}
// The first header byte arrives synchronously and the rest asynchronously.
// ReadFrames() goes pending and completes with the whole frame.
TEST_F(WebSocketBasicStreamSocketTest, HeaderFragmentedSyncAsync) {
MockRead reads[] = {MockRead(SYNCHRONOUS, kSampleFrame, 1),
MockRead(ASYNC, kSampleFrame + 1, kSampleFrameSize - 1)};
CreateStream(reads, base::span<MockWrite>());
int result = stream_->ReadFrames(&frames_, cb_.callback());
ASSERT_THAT(result, IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(UINT64_C(6), frames_[0]->header.payload_length);
}
// Only 9 of the 10 header bytes of a 64-bit-length frame arrive, followed by a
// read that returns ERR_IO_PENDING. Without a complete header, ReadFrames()
// must stay pending rather than return a frame.
TEST_F(WebSocketBasicStreamSocketTest, FragmentedLargeHeader) {
MockRead reads[] = {
MockRead(SYNCHRONOUS, kPartialLargeFrame, kLargeFrameHeaderSize - 1),
MockRead(SYNCHRONOUS, ERR_IO_PENDING)};
CreateStream(reads, base::span<MockWrite>());
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
}
// A read holding the header plus only part of a large frame's payload yields a
// non-final frame chunk that carries exactly the payload bytes received.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, LargeFrameFirstChunk) {
CreateRead(MockRead(SYNCHRONOUS, kPartialLargeFrame, kPartialLargeFrameSize));
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_FALSE(frames_[0]->header.final);
EXPECT_EQ(kPartialLargeFrameSize - kLargeFrameHeaderSize,
static_cast<size_t>(frames_[0]->header.payload_length));
}
// A read that delivers exactly the header of a large frame still produces a
// frame: text opcode, null payload pointer and zero payload length.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, HeaderOnlyChunk) {
CreateRead(MockRead(SYNCHRONOUS, kPartialLargeFrame, kLargeFrameHeaderSize));
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(nullptr, frames_[0]->payload);
EXPECT_EQ(0U, frames_[0]->header.payload_length);
EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, frames_[0]->header.opcode);
}
// The header arrives synchronously and the body asynchronously. The first
// ReadFrames() returns a header-only text chunk. The second goes pending and
// delivers the body as a continuation chunk.
TEST_F(WebSocketBasicStreamSocketTest, HeaderBodySeparated) {
MockRead reads[] = {
MockRead(SYNCHRONOUS, kPartialLargeFrame, kLargeFrameHeaderSize),
MockRead(ASYNC,
kPartialLargeFrame + kLargeFrameHeaderSize,
kPartialLargeFrameSize - kLargeFrameHeaderSize)};
CreateStream(reads, base::span<MockWrite>());
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(nullptr, frames_[0]->payload);
EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, frames_[0]->header.opcode);
frames_.clear();
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(kPartialLargeFrameSize - kLargeFrameHeaderSize,
frames_[0]->header.payload_length);
EXPECT_EQ(WebSocketFrameHeader::kOpCodeContinuation,
frames_[0]->header.opcode);
}
// A large frame arrives as two 16-byte asynchronous reads. The first chunk
// holds 16 - 10 header bytes = 6 payload bytes. The second is a full 16-byte
// payload chunk.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, LargeFrameTwoChunks) {
const size_t kChunkSize = 16;
CreateChunkedRead(ASYNC,
kPartialLargeFrame,
kPartialLargeFrameSize,
kChunkSize,
2,
LAST_FRAME_NOT_BIG);
TestCompletionCallback cb[2];
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[0].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[0].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(kChunkSize - kLargeFrameHeaderSize,
frames_[0]->header.payload_length);
frames_.clear();
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[1].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[1].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(kChunkSize, frames_[0]->header.payload_length);
}
// A final frame split into a 4-byte read and a read with the rest. Only the
// chunk that completes the frame reports |final|.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, OnlyFinalChunkIsFinal) {
static const size_t kFirstChunkSize = 4;
CreateChunkedRead(ASYNC,
kSampleFrame,
kSampleFrameSize,
kFirstChunkSize,
2,
LAST_FRAME_BIG);
TestCompletionCallback cb[2];
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[0].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[0].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
ASSERT_FALSE(frames_[0]->header.final);
frames_.clear();
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[1].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[1].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
ASSERT_TRUE(frames_[0]->header.final);
}
// A text frame read in three asynchronous chunks (3, 3, then the remaining 2
// bytes). The first chunk keeps the text opcode; every later chunk of the same
// frame is reported with the continuation opcode.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, ContinuationOpCodeUsed) {
const size_t kFirstChunkSize = 3;
const int kChunkCount = 3;
CreateChunkedRead(ASYNC,
kSampleFrame,
kSampleFrameSize,
kFirstChunkSize,
kChunkCount,
LAST_FRAME_BIG);
TestCompletionCallback cb[kChunkCount];
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[0].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[0].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, frames_[0]->header.opcode);
for (int i = 1; i < kChunkCount; ++i) {
frames_.clear();
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[i].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[i].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodeContinuation,
frames_[0]->header.opcode);
}
}
// Three complete frames delivered by a single read are all returned by one
// ReadFrames() call.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, ThreeFramesTogether) {
CreateRead(MockRead(SYNCHRONOUS, kMultipleFrames, kMultipleFramesSize));
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(3U, frames_.size());
EXPECT_TRUE(frames_[0]->header.final);
EXPECT_TRUE(frames_[1]->header.final);
EXPECT_TRUE(frames_[2]->header.final);
}
// A zero-length synchronous read means the peer closed the connection.
// ReadFrames() reports it synchronously as ERR_CONNECTION_CLOSED.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, SyncClose) {
  CreateRead(MockRead(SYNCHRONOUS, "", 0));
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_CONNECTION_CLOSED));
}
// A zero-length asynchronous read is reported as ERR_CONNECTION_CLOSED
// through the completion callback.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, AsyncClose) {
CreateRead(MockRead(ASYNC, "", 0));
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsError(ERR_CONNECTION_CLOSED));
}
// A synchronous read that fails with ERR_CONNECTION_CLOSED is passed straight
// through by ReadFrames().
TEST_F(WebSocketBasicStreamSocketSingleReadTest, SyncCloseWithErr) {
  CreateRead(MockRead(SYNCHRONOUS, ERR_CONNECTION_CLOSED));
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_CONNECTION_CLOSED));
}
// An asynchronous read that fails with ERR_CONNECTION_CLOSED is reported via
// the completion callback.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, AsyncCloseWithErr) {
CreateRead(MockRead(ASYNC, ERR_CONNECTION_CLOSED));
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsError(ERR_CONNECTION_CLOSED));
}
// Arbitrary socket errors from a synchronous read (here
// ERR_INSUFFICIENT_RESOURCES) are returned unchanged by ReadFrames().
TEST_F(WebSocketBasicStreamSocketSingleReadTest, SyncErrorsPassedThrough) {
  CreateRead(MockRead(SYNCHRONOUS, ERR_INSUFFICIENT_RESOURCES));
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_INSUFFICIENT_RESOURCES));
}
// Arbitrary socket errors from an asynchronous read are delivered unchanged
// through the completion callback.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, AsyncErrorsPassedThrough) {
CreateRead(MockRead(ASYNC, ERR_INSUFFICIENT_RESOURCES));
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsError(ERR_INSUFFICIENT_RESOURCES));
}
// One complete frame followed by a close. The chunk size equals the frame size
// and LAST_FRAME_NOT_BIG is used, so the second chunk has no bytes left and
// becomes a zero-length read, i.e. a closed connection. The frame is returned
// first; the next ReadFrames() then reports the close.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, CloseAfterFrame) {
  CreateChunkedRead(SYNCHRONOUS, kSampleFrame, kSampleFrameSize,
                    kSampleFrameSize, 2, LAST_FRAME_NOT_BIG);
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
  EXPECT_EQ(1U, frames_.size());
  frames_.clear();
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_CONNECTION_CLOSED));
}
// The connection closes (zero-length read) after only one header byte has
// arrived. The pending ReadFrames() completes with ERR_CONNECTION_CLOSED.
TEST_F(WebSocketBasicStreamSocketTest, AsyncCloseAfterIncompleteHeader) {
MockRead reads[] = {MockRead(ASYNC, kSampleFrame, 1U),
MockRead(SYNCHRONOUS, "", 0)};
CreateStream(reads, base::span<MockWrite>());
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsError(ERR_CONNECTION_CLOSED));
}
// Like AsyncCloseAfterIncompleteHeader, but the close is signalled by an
// explicit ERR_CONNECTION_CLOSED read result instead of a zero-length read.
TEST_F(WebSocketBasicStreamSocketTest, AsyncErrCloseAfterIncompleteHeader) {
MockRead reads[] = {MockRead(ASYNC, kSampleFrame, 1U),
MockRead(SYNCHRONOUS, ERR_CONNECTION_CLOSED)};
CreateStream(reads, base::span<MockWrite>());
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsError(ERR_CONNECTION_CLOSED));
}
// An empty, non-final first fragment is still delivered as a frame with a
// null payload and zero length.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, EmptyFirstFrame) {
CreateRead(MockRead(SYNCHRONOUS, kEmptyFirstFrame, kEmptyFirstFrameSize));
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(nullptr, frames_[0]->payload);
EXPECT_EQ(0U, frames_[0]->header.payload_length);
}
// After the empty first fragment is returned, the empty middle fragment does
// not by itself complete a ReadFrames() call. The stream keeps reading and
// goes pending on the next (never-completing) read.
TEST_F(WebSocketBasicStreamSocketTest, EmptyMiddleFrame) {
MockRead reads[] = {
MockRead(SYNCHRONOUS, kEmptyFirstFrame, kEmptyFirstFrameSize),
MockRead(SYNCHRONOUS, kEmptyMiddleFrame, kEmptyMiddleFrameSize),
MockRead(SYNCHRONOUS, ERR_IO_PENDING)};
CreateStream(reads, base::span<MockWrite>());
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
EXPECT_EQ(1U, frames_.size());
frames_.clear();
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
}
// Asynchronous variant of EmptyMiddleFrame. The pending ReadFrames() skips
// over the empty middle fragment and completes only when the following pong
// arrives, which is the single frame returned.
TEST_F(WebSocketBasicStreamSocketTest, EmptyMiddleFrameAsync) {
MockRead reads[] = {
MockRead(SYNCHRONOUS, kEmptyFirstFrame, kEmptyFirstFrameSize),
MockRead(ASYNC, kEmptyMiddleFrame, kEmptyMiddleFrameSize),
MockRead(ASYNC, kValidPong, kValidPongSize)};
CreateStream(reads, base::span<MockWrite>());
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
EXPECT_EQ(1U, frames_.size());
frames_.clear();
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodePong, frames_[0]->header.opcode);
}
// An empty, final text frame is delivered with a null payload and zero length.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, EmptyFinalFrame) {
CreateRead(
MockRead(SYNCHRONOUS, kEmptyFinalTextFrame, kEmptyFinalTextFrameSize));
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(nullptr, frames_[0]->payload);
EXPECT_EQ(0U, frames_[0]->header.payload_length);
}
// A message made only of empty fragments (first, middle, final). The first
// ReadFrames() returns the text-opcode first fragment. The second returns a
// single final frame; the empty middle fragment is not returned on its own.
TEST_F(WebSocketBasicStreamSocketTest, ThreeFrameEmptyMessage) {
MockRead reads[] = {
MockRead(SYNCHRONOUS, kEmptyFirstFrame, kEmptyFirstFrameSize),
MockRead(SYNCHRONOUS, kEmptyMiddleFrame, kEmptyMiddleFrameSize),
MockRead(SYNCHRONOUS,
kEmptyFinalContinuationFrame,
kEmptyFinalContinuationFrameSize)};
CreateStream(reads, base::span<MockWrite>());
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, frames_[0]->header.opcode);
frames_.clear();
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_TRUE(frames_[0]->header.final);
}
// Frame data supplied through the HTTP read buffer is parsed without touching
// the socket; no socket reads are scripted at all.
TEST_F(WebSocketBasicStreamSocketTest, HttpReadBufferIsUsed) {
SetHttpReadBuffer(kSampleFrame, kSampleFrameSize);
CreateStream(base::span<MockRead>(), base::span<MockWrite>());
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
ASSERT_TRUE(frames_[0]->payload);
EXPECT_EQ(UINT64_C(6), frames_[0]->header.payload_length);
}
// The first header byte comes from the HTTP read buffer and the rest from an
// asynchronous socket read. The stream joins them into one text frame.
TEST_F(WebSocketBasicStreamSocketSingleReadTest,
PartialFrameHeaderInHttpResponse) {
SetHttpReadBuffer(kSampleFrame, 1);
CreateRead(MockRead(ASYNC, kSampleFrame + 1, kSampleFrameSize - 1));
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
ASSERT_TRUE(frames_[0]->payload);
EXPECT_EQ(UINT64_C(6), frames_[0]->header.payload_length);
EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, frames_[0]->header.opcode);
}
// The first 3 bytes of a close frame come from the HTTP read buffer and the
// rest from an asynchronous read. The control frame is reassembled whole,
// with its full 9-byte payload (status code + reason).
TEST_F(WebSocketBasicStreamSocketSingleReadTest,
PartialControlFrameInHttpResponse) {
const size_t kPartialFrameBytes = 3;
SetHttpReadBuffer(kCloseFrame, kPartialFrameBytes);
CreateRead(MockRead(ASYNC,
kCloseFrame + kPartialFrameBytes,
kCloseFrameSize - kPartialFrameBytes));
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodeClose, frames_[0]->header.opcode);
// The 2-byte base header is the only non-payload part of kCloseFrame.
EXPECT_EQ(kCloseFrameSize - 2, frames_[0]->header.payload_length);
EXPECT_EQ(std::string(frames_[0]->payload, kCloseFrameSize - 2),
std::string(kCloseFrame + 2, kCloseFrameSize - 2));
}
// Synchronous variant of PartialControlFrameInHttpResponse. The remainder of
// the close frame is read synchronously and the frame is returned immediately.
TEST_F(WebSocketBasicStreamSocketSingleReadTest,
PartialControlFrameInHttpResponseSync) {
const size_t kPartialFrameBytes = 3;
SetHttpReadBuffer(kCloseFrame, kPartialFrameBytes);
CreateRead(MockRead(SYNCHRONOUS,
kCloseFrame + kPartialFrameBytes,
kCloseFrameSize - kPartialFrameBytes));
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodeClose, frames_[0]->header.opcode);
}
// A malformed frame (kInvalidFrame, whose 16-bit extended length field
// encodes just 7) read synchronously is rejected with ERR_WS_PROTOCOL_ERROR.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, SyncInvalidFrame) {
  CreateRead(MockRead(SYNCHRONOUS, kInvalidFrame, kInvalidFrameSize));
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_WS_PROTOCOL_ERROR));
}
// The same malformed frame read asynchronously is reported as
// ERR_WS_PROTOCOL_ERROR through the completion callback.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, AsyncInvalidFrame) {
CreateRead(MockRead(ASYNC, kInvalidFrame, kInvalidFrameSize));
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsError(ERR_WS_PROTOCOL_ERROR));
}
// A ping without FIN would be a fragmented control frame. It is rejected with
// ERR_WS_PROTOCOL_ERROR and no frame is returned.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, ControlFrameWithoutFin) {
  CreateRead(
      MockRead(SYNCHRONOUS, kPingFrameWithoutFin, kPingFrameWithoutFinSize));
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_WS_PROTOCOL_ERROR));
  EXPECT_TRUE(frames_.empty());
}
// A pong declaring a 126-byte payload exceeds the control-frame size limit.
// It is rejected with ERR_WS_PROTOCOL_ERROR and no frame is returned.
TEST_F(WebSocketBasicStreamSocketSingleReadTest, OverlongControlFrame) {
  CreateRead(MockRead(SYNCHRONOUS, k126BytePong, k126BytePongSize));
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_WS_PROTOCOL_ERROR));
  EXPECT_TRUE(frames_.empty());
}
// The overlong pong, split into a 16-byte read and the rest, is still rejected
// synchronously. Strict I/O checking is turned off because the second chunk
// is not expected to be consumed (presumably the error is detected from the
// header alone).
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, SplitOverlongControlFrame) {
  const size_t kFirstChunkSize = 16;
  expect_all_io_to_complete_ = false;
  CreateChunkedRead(SYNCHRONOUS, k126BytePong, k126BytePongSize,
                    kFirstChunkSize, 2, LAST_FRAME_BIG);
  EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
              IsError(ERR_WS_PROTOCOL_ERROR));
  EXPECT_TRUE(frames_.empty());
}
// Asynchronous variant of SplitOverlongControlFrame. The protocol error is
// delivered through the callback and no frame is returned. Strict I/O
// checking is off because the second chunk is not expected to be consumed.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest,
AsyncSplitOverlongControlFrame) {
const size_t kFirstChunkSize = 16;
expect_all_io_to_complete_ = false;
CreateChunkedRead(ASYNC,
k126BytePong,
k126BytePongSize,
kFirstChunkSize,
2,
LAST_FRAME_BIG);
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsError(ERR_WS_PROTOCOL_ERROR));
EXPECT_TRUE(frames_.empty());
}
// A close frame split into synchronous reads of 3, 3 and the remaining 5
// bytes is reassembled and returned as one control frame.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, SyncControlFrameAssembly) {
const size_t kChunkSize = 3;
CreateChunkedRead(
SYNCHRONOUS, kCloseFrame, kCloseFrameSize, kChunkSize, 3, LAST_FRAME_BIG);
EXPECT_THAT(stream_->ReadFrames(&frames_, cb_.callback()), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodeClose, frames_[0]->header.opcode);
}
// Asynchronous variant of SyncControlFrameAssembly. Control frames are never
// returned in pieces, so a single pending ReadFrames() yields the whole frame.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, AsyncControlFrameAssembly) {
const size_t kChunkSize = 3;
CreateChunkedRead(
ASYNC, kCloseFrame, kCloseFrameSize, kChunkSize, 3, LAST_FRAME_BIG);
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_EQ(WebSocketFrameHeader::kOpCodeClose, frames_[0]->header.opcode);
}
// A text frame with a 1 MiB payload is read in 1000-byte asynchronous chunks.
// Each ReadFrames() call returns exactly one chunk matching the read: the
// first chunk loses the 10 header bytes, and the last holds whatever remains.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, OneMegFrame) {
const int kReadBufferSize = 1000;
const uint64_t kPayloadSize = 1 << 20;
const size_t kWireSize = kPayloadSize + kLargeFrameHeaderSize;
// ceil(kWireSize / kReadBufferSize) reads are needed to deliver the frame.
const size_t kExpectedFrameCount =
(kWireSize + kReadBufferSize - 1) / kReadBufferSize;
auto big_frame = std::make_unique<char[]>(kWireSize);
// FIN + text opcode, then 0x7F: 64-bit big-endian length follows.
memcpy(big_frame.get(), "\x81\x7F", 2);
base::WriteBigEndian(big_frame.get() + 2, kPayloadSize);
memset(big_frame.get() + kLargeFrameHeaderSize, 'A', kPayloadSize);
CreateChunkedRead(ASYNC,
big_frame.get(),
kWireSize,
kReadBufferSize,
kExpectedFrameCount,
LAST_FRAME_BIG);
for (size_t frame = 0; frame < kExpectedFrameCount; ++frame) {
frames_.clear();
ASSERT_THAT(stream_->ReadFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
size_t expected_payload_size = kReadBufferSize;
if (frame == 0) {
expected_payload_size = kReadBufferSize - kLargeFrameHeaderSize;
} else if (frame == kExpectedFrameCount - 1) {
expected_payload_size =
kWireSize - kReadBufferSize * (kExpectedFrameCount - 1);
}
EXPECT_EQ(expected_payload_size, frames_[0]->header.payload_length);
}
}
// A frame with RSV1 set (0x41 = RSV1 | text opcode, no FIN) is read in two
// chunks. Only the first chunk reports |reserved1|; later chunks of the same
// frame have it cleared.
TEST_F(WebSocketBasicStreamSocketChunkedReadTest, ReservedFlagCleared) {
static const char kReservedFlagFrame[] = "\x41\x05Hello";
const size_t kReservedFlagFrameSize = std::size(kReservedFlagFrame) - 1;
const size_t kChunkSize = 5;
CreateChunkedRead(ASYNC,
kReservedFlagFrame,
kReservedFlagFrameSize,
kChunkSize,
2,
LAST_FRAME_BIG);
TestCompletionCallback cb[2];
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[0].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[0].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_TRUE(frames_[0]->header.reserved1);
frames_.clear();
ASSERT_THAT(stream_->ReadFrames(&frames_, cb[1].callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb[1].WaitForResult(), IsOk());
ASSERT_EQ(1U, frames_.size());
EXPECT_FALSE(frames_[0]->header.reserved1);
}
// The queued frame is serialised to exactly kWriteFrame in a single
// synchronous write.
TEST_F(WebSocketBasicStreamSocketWriteTest, WriteAtOnce) {
MockWrite writes[] = {MockWrite(SYNCHRONOUS, kWriteFrame, kWriteFrameSize)};
CreateStream(base::span<MockRead>(), writes);
EXPECT_THAT(stream_->WriteFrames(&frames_, cb_.callback()), IsOk());
}
// Same as WriteAtOnce with an asynchronous write. WriteFrames() goes pending
// and completes through the callback.
TEST_F(WebSocketBasicStreamSocketWriteTest, AsyncWriteAtOnce) {
MockWrite writes[] = {MockWrite(ASYNC, kWriteFrame, kWriteFrameSize)};
CreateStream(base::span<MockRead>(), writes);
ASSERT_THAT(stream_->WriteFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
}
// The socket accepts the frame in three partial writes: 4 bytes synchronously,
// 4 asynchronously, then the rest asynchronously. The stream keeps writing
// until all bytes are sent and only then runs the callback.
TEST_F(WebSocketBasicStreamSocketWriteTest, WriteInBits) {
MockWrite writes[] = {MockWrite(SYNCHRONOUS, kWriteFrame, 4),
MockWrite(ASYNC, kWriteFrame + 4, 4),
MockWrite(ASYNC, kWriteFrame + 8, kWriteFrameSize - 8)};
CreateStream(base::span<MockRead>(), writes);
ASSERT_THAT(stream_->WriteFrames(&frames_, cb_.callback()),
IsError(ERR_IO_PENDING));
EXPECT_THAT(cb_.WaitForResult(), IsOk());
}
// A masked pong whose payload pointer is never set and whose length is zero is
// written as just the header and masking key (kMaskedEmptyPong). It uses a
// local frame vector so the fixture's prepared frame is not written.
TEST_F(WebSocketBasicStreamSocketWriteTest, WriteNullptrPong) {
MockWrite writes[] = {
MockWrite(SYNCHRONOUS, kMaskedEmptyPong, kMaskedEmptyPongSize)};
CreateStream(base::span<MockRead>(), writes);
auto frame =
std::make_unique<WebSocketFrame>(WebSocketFrameHeader::kOpCodePong);
WebSocketFrameHeader& header = frame->header;
header.final = true;
header.masked = true;
header.payload_length = 0;
std::vector<std::unique_ptr<WebSocketFrame>> frames;
frames.push_back(std::move(frame));
EXPECT_THAT(stream_->WriteFrames(&frames, cb_.callback()), IsOk());
}
// With a non-zero masking key, payload bytes are XORed with the repeating key
// on the wire. Masking "graphics" with {0x0d, 0x1b, 0x06, 0x17} gives
// "jiggered".
TEST_F(WebSocketBasicStreamSocketTest, WriteNonNulMask) {
// Expected wire bytes: FIN + text, MASK bit with length 8, the key, then the
// masked payload.
std::string masked_frame = std::string("\x81\x88");
masked_frame += std::string(kNonNulMaskingKey.key, 4);
masked_frame += "jiggered";
MockWrite writes[] = {
MockWrite(SYNCHRONOUS, masked_frame.data(), masked_frame.size())};
generator_ = &GenerateNonNulMaskingKey;
CreateStream(base::span<MockRead>(), writes);
auto frame =
std::make_unique<WebSocketFrame>(WebSocketFrameHeader::kOpCodeText);
const std::string unmasked_payload = "graphics";
const size_t payload_size = unmasked_payload.size();
auto buffer = base::MakeRefCounted<IOBuffer>(payload_size);
memcpy(buffer->data(), unmasked_payload.data(), payload_size);
frame->payload = buffer->data();
WebSocketFrameHeader& header = frame->header;
header.final = true;
header.masked = true;
header.payload_length = payload_size;
frames_.push_back(std::move(frame));
// |buffer| stays alive until the end of the test, covering this synchronous
// write.
EXPECT_THAT(stream_->WriteFrames(&frames_, cb_.callback()), IsOk());
}
// GetExtensions() reports the extensions string supplied at creation.
TEST_F(WebSocketBasicStreamSocketTest, GetExtensionsWorks) {
extensions_ = "inflate-uuencode";
CreateStream(base::span<MockRead>(), base::span<MockWrite>());
EXPECT_EQ("inflate-uuencode", stream_->GetExtensions());
}
// GetSubProtocol() reports the sub-protocol supplied at creation.
TEST_F(WebSocketBasicStreamSocketTest, GetSubProtocolWorks) {
sub_protocol_ = "cyberchat";
CreateStream(base::span<MockRead>(), base::span<MockWrite>());
EXPECT_EQ("cyberchat", stream_->GetSubProtocol());
}
// BufferSizeManager starts at kSmall and stays there after a read has merely
// started (OnRead() without OnReadComplete()).
TEST_F(WebSocketBasicStreamSwitchTest, GetInitialReadBufferSize) {
EXPECT_EQ(buffer_size_manager_.buffer_size(),
WebSocketBasicStream::BufferSize::kSmall);
buffer_size_manager_.OnRead(MicrosecondsFromStart(0));
EXPECT_EQ(buffer_size_manager_.buffer_size(),
WebSocketBasicStream::BufferSize::kSmall);
}
// With a window of 1, a single 1000-byte read that starts and completes at the
// same instant (zero elapsed time) switches the buffer size to kLarge.
TEST_F(WebSocketBasicStreamSwitchTest, ZeroSecondRead) {
buffer_size_manager_.set_window_for_test(1);
buffer_size_manager_.OnRead(MicrosecondsFromStart(0));
buffer_size_manager_.OnReadComplete(MicrosecondsFromStart(0), 1000);
EXPECT_EQ(buffer_size_manager_.buffer_size(),
WebSocketBasicStream::BufferSize::kLarge);
}
// With a window of 4, four 1000-byte reads take 200 us each, with gaps of
// 600 us, 300 us and 300 us between them. Together they switch the buffer size
// to kLarge.
TEST_F(WebSocketBasicStreamSwitchTest, CheckSwitch) {
buffer_size_manager_.set_window_for_test(4);
buffer_size_manager_.OnRead(MicrosecondsFromStart(0));
buffer_size_manager_.OnReadComplete(MicrosecondsFromStart(200), 1000);
buffer_size_manager_.OnRead(MicrosecondsFromStart(800));
buffer_size_manager_.OnReadComplete(MicrosecondsFromStart(1000), 1000);
buffer_size_manager_.OnRead(MicrosecondsFromStart(1300));
buffer_size_manager_.OnReadComplete(MicrosecondsFromStart(1500), 1000);
buffer_size_manager_.OnRead(MicrosecondsFromStart(1800));
buffer_size_manager_.OnReadComplete(MicrosecondsFromStart(2000), 1000);
EXPECT_EQ(buffer_size_manager_.buffer_size(),
WebSocketBasicStream::BufferSize::kLarge);
}
}
}