#include "net/ntlm/ntlm_client.h"
#include <string>
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/strings/string_util.h"
#include "build/build_config.h"
#include "net/ntlm/ntlm.h"
#include "net/ntlm/ntlm_buffer_reader.h"
#include "net/ntlm/ntlm_buffer_writer.h"
#include "net/ntlm/ntlm_test_data.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net::ntlm {
namespace {
// Asks |client| to build an AUTHENTICATE message answering |challenge_msg|,
// using the fixed test:: credentials, hostname, channel bindings, SPN,
// timestamp and client challenge. Callers in this file treat an empty result
// as "challenge rejected".
std::vector<uint8_t> GenerateAuthMsg(const NtlmClient& client,
                                     base::span<const uint8_t> challenge_msg) {
  // The channel-bindings fixture is reinterpreted as characters for the
  // client API.
  const char* const channel_bindings =
      reinterpret_cast<const char*>(test::kChannelBindings);
  return client.GenerateAuthenticateMessage(
      test::kNtlmDomain, test::kUser, test::kPassword, test::kHostnameAscii,
      channel_bindings, test::kNtlmSpn, test::kClientTimestamp,
      test::kClientChallenge, challenge_msg);
}
// Convenience overload: runs GenerateAuthMsg() on the buffer held by
// |challenge_writer|.
std::vector<uint8_t> GenerateAuthMsg(const NtlmClient& client,
                                     const NtlmBufferWriter& challenge_writer) {
  const auto& challenge_bytes = challenge_writer.GetBuffer();
  return GenerateAuthMsg(client, challenge_bytes);
}
// Returns true if |client| produced a non-empty AUTHENTICATE message for the
// challenge in |challenge_writer|. The tests use this as the "challenge was
// accepted" signal.
bool GetAuthMsgResult(const NtlmClient& client,
                      const NtlmBufferWriter& challenge_writer) {
  const std::vector<uint8_t> auth_msg =
      GenerateAuthMsg(client, challenge_writer);
  return auth_msg.size() != 0;
}
// Extracts the payload described by |sec_buf| from |reader| (via
// ReadPayloadAsBufferReader) and copies it into |buffer|.
// |buffer| must be exactly |sec_buf.length| bytes long (CHECKed).
// Returns false if the payload cannot be located or read.
bool ReadBytesFrom(NtlmBufferReader* reader,
                   const SecurityBuffer& sec_buf,
                   base::span<uint8_t> buffer) {
  CHECK_EQ(sec_buf.length, buffer.size());
  NtlmBufferReader payload_reader;
  if (!reader->ReadPayloadAsBufferReader(sec_buf, &payload_reader)) {
    return false;
  }
  return payload_reader.ReadBytes(buffer);
}
// Reads the next security buffer from |reader| and copies its payload into
// |buffer|. Returns false if the security buffer cannot be read, if its
// declared length differs from |buffer.size()|, or if the payload cannot be
// read.
bool ReadBytesPayload(NtlmBufferReader* reader, base::span<uint8_t> buffer) {
  SecurityBuffer sec_buf;
  if (!reader->ReadSecurityBuffer(&sec_buf)) {
    return false;
  }
  // Rejecting a size mismatch here keeps ReadBytesFrom()'s CHECK from firing
  // on unexpected message contents.
  if (sec_buf.length != buffer.size()) {
    return false;
  }
  return ReadBytesFrom(reader, sec_buf, buffer);
}
// Reads the next security buffer from |reader| and stores its raw payload
// bytes (no character decoding) in |str|. Returns false on any read failure.
// In that case |str| may already have been resized.
bool ReadStringPayload(NtlmBufferReader* reader, std::string* str) {
  SecurityBuffer sec_buf;
  if (!reader->ReadSecurityBuffer(&sec_buf)) {
    return false;
  }
  str->resize(sec_buf.length);
  return ReadBytesFrom(reader, sec_buf, base::as_writable_byte_span(*str));
}
// Reads the next security buffer from |reader|, treats its payload as
// little-endian UTF-16 code units and stores them in |str| in host byte order.
// Returns false if the security buffer cannot be read, if it has an odd byte
// length, or if its payload cannot be read. |str| is only written on success.
bool ReadString16Payload(NtlmBufferReader* reader, std::u16string* str) {
  SecurityBuffer sec_buf;
  if (!reader->ReadSecurityBuffer(&sec_buf)) {
    return false;
  }
  // Every UTF-16 code unit occupies two bytes.
  if (sec_buf.length % 2 != 0) {
    return false;
  }
  std::vector<uint8_t> bytes(sec_buf.length);
  if (!ReadBytesFrom(reader, sec_buf, bytes)) {
    return false;
  }
#if defined(ARCH_CPU_BIG_ENDIAN)
  // The bytes are copied verbatim into char16_t storage below, so on a
  // big-endian host each little-endian pair must be swapped first.
  for (size_t pos = 0; pos < bytes.size(); pos += 2) {
    std::swap(bytes[pos], bytes[pos + 1]);
  }
#endif
  str->resize(bytes.size() / 2);
  base::as_writable_byte_span(*str).copy_from(bytes);
  return true;
}
void MakeV2ChallengeMessage(size_t target_info_len, std::vector<uint8_t>* out) {
static const size_t kChallengeV2HeaderLen = 56;
size_t server_name_len = target_info_len - kAvPairHeaderLen * 2;
NtlmBufferWriter challenge(kChallengeV2HeaderLen + target_info_len);
ASSERT_TRUE(challenge.WriteMessageHeader(MessageType::kChallenge));
ASSERT_TRUE(
challenge.WriteSecurityBuffer(SecurityBuffer(0, 0)));
ASSERT_TRUE(challenge.WriteFlags(NegotiateFlags::kTargetInfo));
ASSERT_TRUE(challenge.WriteZeros(kChallengeLen));
ASSERT_TRUE(challenge.WriteZeros(8));
ASSERT_TRUE(challenge.WriteSecurityBuffer(
SecurityBuffer(kChallengeV2HeaderLen, target_info_len)));
ASSERT_TRUE(challenge.WriteZeros(8));
ASSERT_EQ(kChallengeV2HeaderLen, challenge.GetCursor());
ASSERT_TRUE(challenge.WriteAvPair(
AvPair(TargetInfoAvId::kServerName,
std::vector<uint8_t>(server_name_len, 'a'))));
ASSERT_TRUE(challenge.WriteAvPairTerminator());
ASSERT_TRUE(challenge.IsEndOfBuffer());
*out = challenge.Pass();
}
}
// A client built with NTLMv2 disabled reports v2, EPA and MIC all as off.
TEST(NtlmClientTest, SimpleConstructionV1) {
  NtlmFeatures v1_features(false);
  NtlmClient client(v1_features);
  ASSERT_FALSE(client.IsNtlmV2());
  ASSERT_FALSE(client.IsEpaEnabled());
  ASSERT_FALSE(client.IsMicEnabled());
}
// The v1 client's NEGOTIATE message must have the fixed negotiate length and
// match test::kExpectedNegotiateMsg byte for byte.
TEST(NtlmClientTest, VerifyNegotiateMessageV1) {
  NtlmClient client(NtlmFeatures(false));
  std::vector<uint8_t> negotiate_msg = client.GetNegotiateMessage();
  const auto expected = base::span(test::kExpectedNegotiateMsg);
  ASSERT_EQ(kNegotiateMessageLen, negotiate_msg.size());
  ASSERT_EQ(expected, base::span(negotiate_msg));
}
// A challenge made of only the minimal fixed-size header is accepted.
TEST(NtlmClientTest, MinimalStructurallyValidChallenge) {
  NtlmClient client(NtlmFeatures(false));
  const auto min_header =
      base::span(test::kMinChallengeMessage).first<kMinChallengeHeaderLen>();

  NtlmBufferWriter writer(kMinChallengeHeaderLen);
  ASSERT_TRUE(writer.WriteBytes(min_header));
  ASSERT_TRUE(GetAuthMsgResult(client, writer));
}
// Zeroing the target-name offset must still be accepted. The minimal message
// declares no target-name payload, and with no payload the offset should be
// ignored.
TEST(NtlmClientTest, MinimalStructurallyValidChallengeZeroOffset) {
  // Byte 16 is the low byte of the target-name security buffer's offset.
  constexpr size_t kTargetNameOffsetPos = 16;

  NtlmClient client(NtlmFeatures(false));
  uint8_t challenge_bytes[kMinChallengeHeaderLen];
  base::span(challenge_bytes)
      .copy_from(
          base::span(test::kMinChallengeMessage).first(kMinChallengeHeaderLen));

  ASSERT_NE(0x00, challenge_bytes[kTargetNameOffsetPos]);
  challenge_bytes[kTargetNameOffsetPos] = 0x00;

  NtlmBufferWriter writer(kMinChallengeHeaderLen);
  ASSERT_TRUE(writer.WriteBytes(challenge_bytes));
  ASSERT_TRUE(GetAuthMsgResult(client, writer));
}
// A challenge one byte shorter than the minimal header must be rejected.
TEST(NtlmClientTest, ChallengeMsgTooShort) {
  constexpr auto kTruncatedLen = kMinChallengeHeaderLen - 1;

  NtlmClient client(NtlmFeatures(false));
  NtlmBufferWriter writer(kTruncatedLen);
  ASSERT_TRUE(writer.WriteBytes(
      base::span(test::kMinChallengeMessage).first<kTruncatedLen>()));
  ASSERT_FALSE(GetAuthMsgResult(client, writer));
}
// Corrupting the final byte of the 8-byte message signature must cause the
// challenge to be rejected.
TEST(NtlmClientTest, ChallengeMsgNoSig) {
  constexpr size_t kLastSignatureBytePos = 7;

  NtlmClient client(NtlmFeatures(false));
  uint8_t challenge_bytes[kMinChallengeHeaderLen];
  base::span(challenge_bytes)
      .copy_from(
          base::span(test::kMinChallengeMessage).first(kMinChallengeHeaderLen));

  ASSERT_NE(0xff, challenge_bytes[kLastSignatureBytePos]);
  challenge_bytes[kLastSignatureBytePos] = 0xff;

  NtlmBufferWriter writer(kMinChallengeHeaderLen);
  ASSERT_TRUE(writer.WriteBytes(challenge_bytes));
  ASSERT_FALSE(GetAuthMsgResult(client, writer));
}
// Rewriting the message-type field (starting at byte 8) to 3, the
// AUTHENTICATE type, must cause the challenge to be rejected.
TEST(NtlmClientTest, ChallengeMsgWrongMessageType) {
  constexpr size_t kMessageTypePos = 8;

  NtlmClient client(NtlmFeatures(false));
  uint8_t challenge_bytes[kMinChallengeHeaderLen];
  base::span(challenge_bytes)
      .copy_from(
          base::span(test::kMinChallengeMessage).first(kMinChallengeHeaderLen));

  ASSERT_NE(0x03, challenge_bytes[kMessageTypePos]);
  challenge_bytes[kMessageTypePos] = 0x03;

  NtlmBufferWriter writer(kMinChallengeHeaderLen);
  ASSERT_TRUE(writer.WriteBytes(challenge_bytes));
  ASSERT_FALSE(GetAuthMsgResult(client, writer));
}
// A challenge that declares no target name, with a zero target-name offset,
// is accepted.
// NOTE(review): this body is identical to
// MinimalStructurallyValidChallengeZeroOffset. Consider deleting one of them or
// making this test cover a distinct case.
TEST(NtlmClientTest, ChallengeWithNoTargetName) {
  // Low byte of the target-name security buffer's offset field.
  constexpr size_t kOffsetLowBytePos = 16;

  NtlmClient client(NtlmFeatures(false));
  uint8_t message[kMinChallengeHeaderLen];
  base::span(message).copy_from(
      base::span(test::kMinChallengeMessage).first(kMinChallengeHeaderLen));

  ASSERT_NE(0x00, message[kOffsetLowBytePos]);
  message[kOffsetLowBytePos] = 0x00;

  NtlmBufferWriter writer(kMinChallengeHeaderLen);
  ASSERT_TRUE(writer.WriteBytes(message));
  ASSERT_TRUE(GetAuthMsgResult(client, writer));
}
// A challenge that carries a one-byte target name directly after the minimal
// header is accepted.
//
// The message is sized to exactly |raw| (header + 1) so that the target-name
// payload ends on the last byte of the message. The sibling
// NoTargetNameOverflowFromLength test uses the same size to show that one
// more byte of length overflows. Sizing the writer with kChallengeHeaderLen
// instead would pad the message past the payload and leave that tight
// boundary untested.
TEST(NtlmClientTest, Type2MessageWithTargetName) {
  NtlmClient client(NtlmFeatures(false));

  // One extra byte is provided for the target name payload.
  uint8_t raw[kMinChallengeHeaderLen + 1];
  base::span(raw).copy_prefix_from(
      base::span(test::kMinChallengeMessage).first(kMinChallengeHeaderLen));
  raw[kMinChallengeHeaderLen] = 'Z';

  // Set the target-name length (bytes 12-13) and max length (bytes 14-15) to 1.
  ASSERT_NE(0x01, raw[12]);
  ASSERT_EQ(0x00, raw[13]);
  ASSERT_NE(0x01, raw[14]);
  ASSERT_EQ(0x00, raw[15]);
  raw[12] = 0x01;
  raw[14] = 0x01;

  NtlmBufferWriter writer(kMinChallengeHeaderLen + 1);
  ASSERT_TRUE(writer.WriteBytes(raw));
  // The message must contain exactly |raw|, with no trailing padding.
  ASSERT_TRUE(writer.IsEndOfBuffer());
  ASSERT_TRUE(GetAuthMsgResult(client, writer));
}
// Declaring a one-byte target name in a message that stops at the end of the
// minimal header must be rejected. The payload at the declared offset would
// lie outside the message.
TEST(NtlmClientTest, NoTargetNameOverflowFromOffset) {
  // Low bytes of the target-name length and max-length fields.
  constexpr size_t kLengthPos = 12;
  constexpr size_t kMaxLengthPos = 14;

  NtlmClient client(NtlmFeatures(false));
  uint8_t message[kMinChallengeHeaderLen];
  base::span(message).copy_from(
      base::span(test::kMinChallengeMessage).first(kMinChallengeHeaderLen));

  ASSERT_NE(0x01, message[kLengthPos]);
  ASSERT_EQ(0x00, message[kLengthPos + 1]);
  ASSERT_NE(0x01, message[kMaxLengthPos]);
  ASSERT_EQ(0x00, message[kMaxLengthPos + 1]);
  message[kLengthPos] = 0x01;
  message[kMaxLengthPos] = 0x01;

  NtlmBufferWriter writer(kMinChallengeHeaderLen);
  ASSERT_TRUE(writer.WriteBytes(message));
  ASSERT_FALSE(GetAuthMsgResult(client, writer));
}
// A message with exactly one byte of target-name payload, but declaring a
// length of 2, must be rejected because the payload would run past the end of
// the message.
TEST(NtlmClientTest, NoTargetNameOverflowFromLength) {
  // Low bytes of the target-name length and max-length fields.
  constexpr size_t kLengthPos = 12;
  constexpr size_t kMaxLengthPos = 14;

  NtlmClient client(NtlmFeatures(false));
  uint8_t message[kMinChallengeHeaderLen + 1];
  base::span(message).copy_prefix_from(
      base::span(test::kMinChallengeMessage).first(kMinChallengeHeaderLen));
  // The single byte available for the target name payload.
  message[kMinChallengeHeaderLen] = 'Z';

  ASSERT_NE(0x02, message[kLengthPos]);
  ASSERT_EQ(0x00, message[kLengthPos + 1]);
  ASSERT_NE(0x02, message[kMaxLengthPos]);
  ASSERT_EQ(0x00, message[kMaxLengthPos + 1]);
  message[kLengthPos] = 0x02;
  message[kMaxLengthPos] = 0x02;

  NtlmBufferWriter writer(kMinChallengeHeaderLen + 1);
  ASSERT_TRUE(writer.WriteBytes(message));
  ASSERT_FALSE(GetAuthMsgResult(client, writer));
}
// For the v1 challenge fixture, the v1 client's AUTHENTICATE message must equal
// test::kExpectedAuthenticateMsgSpecResponseV1 exactly.
TEST(NtlmClientTest, Type3UnicodeWithSessionSecuritySpecTest) {
  NtlmClient client(NtlmFeatures(false));
  std::vector<uint8_t> auth_msg = GenerateAuthMsg(client, test::kChallengeMsgV1);
  const auto expected =
      base::span(test::kExpectedAuthenticateMsgSpecResponseV1);

  ASSERT_FALSE(auth_msg.empty());
  ASSERT_EQ(expected.size(), auth_msg.size());
  ASSERT_EQ(expected, base::span(auth_msg));
}
// When the challenge does not offer Unicode, the AUTHENTICATE message must
// carry 8-bit (OEM) domain, user and host strings. Its flags must include
// kOem and exclude kUnicode.
TEST(NtlmClientTest, Type3WithoutUnicode) {
  NtlmClient client(NtlmFeatures(false));
  const auto challenge = base::span(test::kMinChallengeMessageNoUnicode)
                             .first<kMinChallengeHeaderLen>();
  std::vector<uint8_t> auth_msg = GenerateAuthMsg(client, challenge);
  ASSERT_FALSE(auth_msg.empty());

  NtlmBufferReader reader(auth_msg);
  ASSERT_TRUE(reader.MatchMessageHeader(MessageType::kAuthenticate));

  // The LM and NTLM response payloads come first, in this order.
  uint8_t lm_response[kResponseLenV1];
  ASSERT_TRUE(ReadBytesPayload(&reader, lm_response));
  uint8_t ntlm_response[kResponseLenV1];
  ASSERT_TRUE(ReadBytesPayload(&reader, ntlm_response));
  ASSERT_EQ(base::span(test::kExpectedLmResponseWithV1SS),
            base::span(lm_response));
  ASSERT_EQ(base::span(test::kExpectedNtlmResponseWithV1SS),
            base::span(ntlm_response));

  // Domain, user and host are read as plain 8-bit strings.
  std::string domain;
  ASSERT_TRUE(ReadStringPayload(&reader, &domain));
  ASSERT_EQ(test::kNtlmDomainAscii, domain);

  std::string username;
  ASSERT_TRUE(ReadStringPayload(&reader, &username));
  ASSERT_EQ(test::kUserAscii, username);

  std::string hostname;
  ASSERT_TRUE(ReadStringPayload(&reader, &hostname));
  ASSERT_EQ(test::kHostnameAscii, hostname);

  // The following security buffer is expected to be empty (presumably the
  // encrypted session key field — confirm against [MS-NLMP] 2.2.1.3).
  ASSERT_TRUE(reader.MatchEmptySecurityBuffer());

  NegotiateFlags flags;
  ASSERT_TRUE(reader.ReadFlags(&flags));
  ASSERT_EQ(NegotiateFlags::kNone, flags & NegotiateFlags::kUnicode);
  ASSERT_EQ(NegotiateFlags::kOem, flags & NegotiateFlags::kOem);
}
// Even when the challenge fixture is the "NoSS" variant, the client must still
// produce v1 session-security responses and set kExtendedSessionSecurity. The
// domain, user and host strings are read as UTF-16 and kUnicode must be set.
TEST(NtlmClientTest, ClientDoesNotDowngradeSessionSecurity) {
  NtlmClient client(NtlmFeatures(false));
  const auto challenge = base::span(test::kMinChallengeMessageNoSS)
                             .first<kMinChallengeHeaderLen>();
  std::vector<uint8_t> auth_msg = GenerateAuthMsg(client, challenge);
  ASSERT_FALSE(auth_msg.empty());

  NtlmBufferReader reader(auth_msg);
  ASSERT_TRUE(reader.MatchMessageHeader(MessageType::kAuthenticate));

  // The LM and NTLM response payloads come first, in this order.
  uint8_t lm_response[kResponseLenV1];
  ASSERT_TRUE(ReadBytesPayload(&reader, lm_response));
  uint8_t ntlm_response[kResponseLenV1];
  ASSERT_TRUE(ReadBytesPayload(&reader, ntlm_response));
  ASSERT_EQ(base::span(test::kExpectedLmResponseWithV1SS),
            base::span(lm_response));
  ASSERT_EQ(base::span(test::kExpectedNtlmResponseWithV1SS),
            base::span(ntlm_response));

  // Domain, user and host are read as UTF-16 strings.
  std::u16string domain;
  ASSERT_TRUE(ReadString16Payload(&reader, &domain));
  ASSERT_EQ(test::kNtlmDomain, domain);

  std::u16string username;
  ASSERT_TRUE(ReadString16Payload(&reader, &username));
  ASSERT_EQ(test::kUser, username);

  std::u16string hostname;
  ASSERT_TRUE(ReadString16Payload(&reader, &hostname));
  ASSERT_EQ(test::kHostname, hostname);

  // The following security buffer is expected to be empty.
  ASSERT_TRUE(reader.MatchEmptySecurityBuffer());

  NegotiateFlags flags;
  ASSERT_TRUE(reader.ReadFlags(&flags));
  ASSERT_EQ(NegotiateFlags::kUnicode, flags & NegotiateFlags::kUnicode);
  ASSERT_EQ(NegotiateFlags::kExtendedSessionSecurity,
            flags & NegotiateFlags::kExtendedSessionSecurity);
}
// A client built with NTLMv2 enabled reports v2, EPA and MIC all as on.
TEST(NtlmClientTest, SimpleConstructionV2) {
  NtlmFeatures v2_features(true);
  NtlmClient client(v2_features);
  ASSERT_TRUE(client.IsNtlmV2());
  ASSERT_TRUE(client.IsEpaEnabled());
  ASSERT_TRUE(client.IsMicEnabled());
}
// The v2 client's NEGOTIATE message is checked against the same
// test::kExpectedNegotiateMsg fixture that the v1 test uses.
TEST(NtlmClientTest, VerifyNegotiateMessageV2) {
  NtlmClient client(NtlmFeatures(true));
  std::vector<uint8_t> negotiate_msg = client.GetNegotiateMessage();
  const auto expected = base::span(test::kExpectedNegotiateMsg);

  ASSERT_FALSE(negotiate_msg.empty());
  ASSERT_EQ(expected.size(), negotiate_msg.size());
  ASSERT_EQ(expected, base::span(negotiate_msg));
}
// For the spec's v2 challenge fixture, the v2 client's AUTHENTICATE message
// must equal test::kExpectedAuthenticateMsgSpecResponseV2 exactly.
TEST(NtlmClientTest, VerifyAuthenticateMessageV2) {
  NtlmClient client(NtlmFeatures(true));
  std::vector<uint8_t> auth_msg =
      GenerateAuthMsg(client, test::kChallengeMsgFromSpecV2);
  const auto expected =
      base::span(test::kExpectedAuthenticateMsgSpecResponseV2);

  ASSERT_FALSE(auth_msg.empty());
  ASSERT_EQ(expected.size(), auth_msg.size());
  ASSERT_EQ(expected, base::span(auth_msg));
}
// A v2 client answering the v1 challenge fixture (which has no target info)
// must produce exactly test::kExpectedAuthenticateMsgToOldV1ChallegeV2.
TEST(NtlmClientTest,
     VerifyAuthenticateMessageInResponseToChallengeWithoutTargetInfoV2) {
  NtlmClient client(NtlmFeatures(true));
  std::vector<uint8_t> auth_msg = GenerateAuthMsg(client, test::kChallengeMsgV1);
  const auto expected =
      base::span(test::kExpectedAuthenticateMsgToOldV1ChallegeV2);

  ASSERT_FALSE(auth_msg.empty());
  ASSERT_EQ(expected.size(), auth_msg.size());
  ASSERT_EQ(expected, base::span(auth_msg));
}
// A v2 challenge whose target info is 0xfff bytes is accepted. One whose
// target info is 0xffff bytes is rejected, presumably because the AV pairs the
// client adds would overflow a 16-bit length (to be confirmed against
// NtlmClient).
TEST(NtlmClientTest, AvPairsOverflow) {
  {
    std::vector<uint8_t> challenge_msg;
    ASSERT_NO_FATAL_FAILURE(MakeV2ChallengeMessage(0xfff, &challenge_msg));
    NtlmClient client(NtlmFeatures(true));
    EXPECT_FALSE(GenerateAuthMsg(client, challenge_msg).empty());
  }
  {
    std::vector<uint8_t> challenge_msg;
    ASSERT_NO_FATAL_FAILURE(MakeV2ChallengeMessage(0xffff, &challenge_msg));
    NtlmClient client(NtlmFeatures(true));
    EXPECT_TRUE(GenerateAuthMsg(client, challenge_msg).empty());
  }
}
}