#include "chrome/chrome_elf/third_party_dlls/main.h"
#include <windows.h>
#include <string>
#include "base/command_line.h"
#include "base/files/file_util.h"
#include "base/files/scoped_temp_dir.h"
#include "base/hash/sha1.h"
#include "base/path_service.h"
#include "base/process/launch.h"
#include "base/scoped_native_library.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/utf_string_conversions.h"
#include "base/test/test_reg_util_win.h"
#include "base/test/test_timeouts.h"
#include "build/build_config.h"
#include "chrome/chrome_elf/nt_registry/nt_registry.h"
#include "chrome/chrome_elf/sha1/sha1.h"
#include "chrome/chrome_elf/third_party_dlls/hook.h"
#include "chrome/chrome_elf/third_party_dlls/main_unittest_exe.h"
#include "chrome/chrome_elf/third_party_dlls/packed_list_file.h"
#include "chrome/chrome_elf/third_party_dlls/packed_list_format.h"
#include "chrome/install_static/install_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace third_party_dlls {
namespace {
// File name of the child test executable that every load test launches.
constexpr wchar_t kTestExeFilename[] = L"third_party_dlls_test_exe.exe";
// File name of the packed blocklist file created inside the per-test temp dir.
constexpr wchar_t kTestBlFileName[] = L"blfile";
// First test DLL. The tests below expect its |image_name| to be empty.
constexpr wchar_t kTestDllName1[] = L"main_unittest_dll_1.dll";
// Same name as |kTestDllName1| but with mixed letter case; used as the
// destination name when copying a test DLL into the temp dir.
constexpr wchar_t kTestDllName1MixedCase[] = L"MaiN_uniTtest_dLL_1.Dll";
// Second test DLL. The tests below expect its |image_name| to be non-empty.
constexpr wchar_t kTestDllName2[] = L"main_unittest_dll_2.dll";
// Four non-ASCII (CJK) code units followed by ".dll" and a terminating NUL.
// Used as a file name that cannot be represented in ASCII.
constexpr wchar_t kChineseUnicode[] = {0x68D5, 0x8272, 0x72D0, 0x72F8, 0x002E,
0x0064, 0x006C, 0x006C, 0x0000};
// DLL name that DeprecatedBlocklistSanityCheck expects to fail to load even
// though that test writes no blocklist entries.
constexpr wchar_t kOldBlocklistDllName[] = L"libapi2hook.dll";
// Bundle of per-image values that GetTestModuleData() fills in through
// GetDataFromImageForTesting().
struct TestModuleData {
// Image name reported for the module; may be empty (see the tests below).
std::string image_name;
// Full section path of the mapped module (a device path in
// PathCaseSensitive, where it is converted to a drive-letter path).
std::string section_path;
// Basename component of |section_path|; used as a blocklist key.
std::string section_basename;
// Passed, with |imagesize|, to GetFingerprintString() to build the code id.
DWORD timedatestamp;
DWORD imagesize;
};
// How long LaunchChildAndWait() waits for the child process to exit: 5000 ms
// normally, or effectively forever when a debugger is attached to this process.
base::TimeDelta g_timeout =
::IsDebuggerPresent() ? base::TimeDelta::Max() : base::Milliseconds(5000);
// Launches |command_line| as a child process and waits up to |g_timeout| for
// it to exit.
//
// On return |*exit_code| holds the child's exit code. It is reset to 0 before
// waiting, and is left untouched if the launch itself fails (a fatal gtest
// failure is recorded in that case). If the wait times out the child is
// terminated and a non-fatal failure is recorded.
void LaunchChildAndWait(const base::CommandLine& command_line, int* exit_code) {
  base::Process child =
      base::LaunchProcess(command_line, base::LaunchOptionsForTest());
  ASSERT_TRUE(child.IsValid());

  *exit_code = 0;
  const bool exited_in_time = child.WaitForExitWithTimeout(g_timeout, exit_code);
  if (exited_in_time)
    return;

  // Timed out: best-effort cleanup of the child, then flag the test.
  child.Terminate(1, false);
  ADD_FAILURE();
}
bool GetTestModuleData(const std::wstring& file_name,
const std::wstring& file_path,
TestModuleData* test_module) {
base::FilePath path(file_path);
path = path.Append(file_name);
base::ScopedNativeLibrary test_dll(path);
if (!test_dll.is_valid())
return false;
return GetDataFromImageForTesting(
test_dll.get(), &test_module->timedatestamp, &test_module->imagesize,
&test_module->image_name, &test_module->section_path,
&test_module->section_basename);
}
// Builds a blocklist entry for a module.
//
// |image_name| is hashed into |basename_hash| and must not be empty.
// |timedatestamp| and |imagesize| are combined by GetFingerprintString() and
// the result is hashed into |code_id_hash|. Only those two fields are assigned.
PackedListModule GeneratePackedListModule(const std::string& image_name,
                                          DWORD timedatestamp,
                                          DWORD imagesize) {
  // Callers are expected to never pass an empty name.
  assert(!image_name.empty());

  PackedListModule entry;
  entry.basename_hash = elf_sha1::SHA1HashString(image_name);
  entry.code_id_hash =
      elf_sha1::SHA1HashString(GetFingerprintString(timedatestamp, imagesize));
  return entry;
}
// Joins directory |path| and file |name| with a single backslash.
// No normalization is done: an empty |path| yields "\\name" and an empty
// |name| yields "path\\".
inline std::wstring MakePath(const std::wstring& path,
                             const std::wstring& name) {
  return path + L'\\' + name;
}
inline bool MakeFileCopy(const std::wstring& old_path,
const std::wstring& old_name,
const std::wstring& new_path,
const std::wstring& new_name) {
base::FilePath source(MakePath(old_path, old_name));
base::FilePath destination(MakePath(new_path, new_name));
return base::CopyFileW(source, destination);
}
// Redirects the registry hive selected by |key| using |rom|, and hands the
// resulting override path to nt::SetTestingOverride() for the same |key|.
// |key| must not be nt::AUTO; nt::HKCU selects HKEY_CURRENT_USER and any other
// value selects HKEY_LOCAL_MACHINE. Failures are reported as fatal gtest
// failures.
void RegRedirect(nt::ROOT_KEY key,
                 registry_util::RegistryOverrideManager* rom) {
  ASSERT_NE(key, nt::AUTO);

  HKEY hive = HKEY_LOCAL_MACHINE;
  if (key == nt::HKCU)
    hive = HKEY_CURRENT_USER;

  std::wstring override_path;
  ASSERT_NO_FATAL_FAILURE(rom->OverrideRegistry(hive, &override_path));
  ASSERT_TRUE(nt::SetTestingOverride(key, override_path));
}
// Clears the nt testing override for |key| by setting it to an empty path.
// |key| must not be nt::AUTO.
void CancelRegRedirect(nt::ROOT_KEY key) {
  ASSERT_NE(key, nt::AUTO);
  const std::wstring no_override;
  ASSERT_TRUE(nt::SetTestingOverride(key, no_override));
}
bool QueryStatusCodes(std::vector<ThirdPartyStatus>* status_array) {
HANDLE handle = nullptr;
if (!nt::OpenRegKey(nt::HKCU,
install_static::GetRegistryPath()
.append(kThirdPartyRegKeyName)
.c_str(),
KEY_QUERY_VALUE, &handle, nullptr)) {
return false;
}
ULONG type = REG_NONE;
std::vector<uint8_t> temp_buffer;
bool success =
nt::QueryRegKeyValue(handle, kStatusCodesRegValue, &type, &temp_buffer);
nt::CloseRegKey(handle);
if (!success || type != REG_BINARY)
return false;
ConvertBufferToStatusCodes(temp_buffer, status_array);
return true;
}
// Test fixture shared by all tests in this file.
//
// SetUp() creates a unique temp directory, records the directory of the
// running executable, and creates an (initially empty) blocklist file inside
// the temp directory. The file is opened with FLAG_DELETE_ON_CLOSE and
// FLAG_WIN_SHARE_DELETE and the handle is kept open for the lifetime of the
// fixture.
class ThirdPartyTest : public testing::Test {
 public:
  ThirdPartyTest(const ThirdPartyTest&) = delete;
  ThirdPartyTest& operator=(const ThirdPartyTest&) = delete;

 protected:
  ThirdPartyTest() = default;

  void SetUp() override {
    // Temp directory for the blocklist file and for renamed DLL copies.
    ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir());
    // value() yields a const reference, so these are plain copies.
    bl_test_file_path_ =
        scoped_temp_dir_.GetPath().Append(kTestBlFileName).value();

    // Directory of the current executable.
    base::FilePath exe;
    ASSERT_TRUE(base::PathService::Get(base::DIR_EXE, &exe));
    exe_dir_ = exe.value();

    // Create the blocklist file and keep the handle for later rewrites.
    base::File file(base::FilePath(bl_test_file_path_),
                    base::File::FLAG_CREATE_ALWAYS | base::File::FLAG_WRITE |
                        base::File::FLAG_WIN_SHARE_DELETE |
                        base::File::FLAG_DELETE_ON_CLOSE);
    ASSERT_TRUE(file.IsValid());
    bl_file_ = std::move(file);
  }

  void TearDown() override {}

  // Replaces the content of the blocklist file with
  // {PackedListMetadata}{array of PackedListModule} built from |list|.
  // Returns false if truncating the file or either write fails.
  bool WriteModulesToBlocklist(const std::vector<PackedListModule>& list) {
    // Drop any previous content; a failed truncate could leave stale entries
    // behind the new data, so treat it as an error.
    if (!bl_file_.SetLength(0))
      return false;

    PackedListMetadata meta = {kInitialVersion,
                               static_cast<uint32_t>(list.size())};
    if (bl_file_.Write(0, reinterpret_cast<const char*>(&meta), sizeof(meta)) !=
        static_cast<int>(sizeof(meta))) {
      return false;
    }

    // The module array immediately follows the metadata header.
    int size = static_cast<int>(list.size() * sizeof(PackedListModule));
    if (bl_file_.Write(sizeof(PackedListMetadata),
                       reinterpret_cast<const char*>(list.data()),
                       size) != size) {
      return false;
    }
    return true;
  }

  // Full path of the blocklist file inside the temp directory.
  const std::wstring& GetBlTestFilePath() { return bl_test_file_path_; }
  // Directory containing the running executable.
  const std::wstring& GetExeDir() { return exe_dir_; }
  // Path of the per-test temp directory.
  const std::wstring& GetScopedTempDirValue() {
    return scoped_temp_dir_.GetPath().value();
  }

 private:
  base::ScopedTempDir scoped_temp_dir_;
  base::File bl_file_;
  std::wstring bl_test_file_path_;
  std::wstring exe_dir_;
};
// The test is registered under a DISABLED_ name when BUILDFLAG(IS_WIN) is set.
#if BUILDFLAG(IS_WIN)
#define MAYBE_Base DISABLED_Base
#else
#define MAYBE_Base Base
#endif
// Exercises initialization and basic blocking through the child test exe:
// init only, an unblocked load, a blocked load, and a blocked load of a
// mixed-case copy of the same DLL.
TEST_F(ThirdPartyTest, MAYBE_Base) {
  // 1) Initialization only, with the (empty) blocklist file.
  base::CommandLine cmd_line1 = base::CommandLine::FromString(kTestExeFilename);
  cmd_line1.AppendArgNative(GetBlTestFilePath());
  cmd_line1.AppendArgNative(base::NumberToWString(kTestOnlyInitialization));
  int exit_code = 0;
  LaunchChildAndWait(cmd_line1, &exit_code);
  ASSERT_EQ(kDllLoadSuccess, exit_code);

  // 2) Loading test DLL 1 succeeds while nothing is blocked.
  base::CommandLine cmd_line2 = base::CommandLine::FromString(kTestExeFilename);
  cmd_line2.AppendArgNative(GetBlTestFilePath());
  cmd_line2.AppendArgNative(base::NumberToWString(kTestSingleDllLoad));
  cmd_line2.AppendArgNative(MakePath(GetExeDir(), kTestDllName1));
  LaunchChildAndWait(cmd_line2, &exit_code);
  ASSERT_EQ(kDllLoadSuccess, exit_code);

  // 3) Add test DLL 1 to the blocklist, keyed by its section basename because
  //    its image name is expected to be empty, and verify the load now fails.
  TestModuleData module_data = {};
  ASSERT_TRUE(GetTestModuleData(kTestDllName1, GetExeDir(), &module_data));
  EXPECT_TRUE(module_data.image_name.empty());
  // Start from an empty list so the file holds exactly one real entry;
  // constructing the vector with a count would prepend a zeroed bogus module.
  std::vector<PackedListModule> vector;
  vector.emplace_back(GeneratePackedListModule(module_data.section_basename,
                                               module_data.timedatestamp,
                                               module_data.imagesize));
  ASSERT_TRUE(WriteModulesToBlocklist(vector));
  base::CommandLine cmd_line3 = base::CommandLine::FromString(kTestExeFilename);
  cmd_line3.AppendArgNative(GetBlTestFilePath());
  cmd_line3.AppendArgNative(base::NumberToWString(kTestSingleDllLoad));
  cmd_line3.AppendArgNative(MakePath(GetExeDir(), kTestDllName1));
  LaunchChildAndWait(cmd_line3, &exit_code);
  ASSERT_EQ(kDllLoadFailed, exit_code);

  // 4) A mixed-case copy of the same DLL is expected to be blocked as well.
  ASSERT_TRUE(MakeFileCopy(GetExeDir(), kTestDllName1, GetScopedTempDirValue(),
                           kTestDllName1MixedCase));
  base::CommandLine cmd_line4 = base::CommandLine::FromString(kTestExeFilename);
  cmd_line4.AppendArgNative(GetBlTestFilePath());
  cmd_line4.AppendArgNative(base::NumberToWString(kTestSingleDllLoad));
  cmd_line4.AppendArgNative(
      MakePath(GetScopedTempDirValue(), kTestDllName1MixedCase));
  LaunchChildAndWait(cmd_line4, &exit_code);
  ASSERT_EQ(kDllLoadFailed, exit_code);
}
// Copies test DLL 1 to a non-ASCII file name and checks that it loads while
// unblocked and fails to load once its section basename is on the blocklist.
TEST_F(ThirdPartyTest, WideCharEncoding) {
  const std::wstring& temp_dir = GetScopedTempDirValue();
  ASSERT_TRUE(
      MakeFileCopy(GetExeDir(), kTestDllName1, temp_dir, kChineseUnicode));
  const std::wstring renamed_dll = MakePath(temp_dir, kChineseUnicode);

  int exit_code = 0;
  // Runs the child exe in single-DLL-load mode against |renamed_dll| and
  // stores its exit code in |exit_code|.
  auto load_renamed_dll = [&]() {
    base::CommandLine child_cmd =
        base::CommandLine::FromString(kTestExeFilename);
    child_cmd.AppendArgNative(GetBlTestFilePath());
    child_cmd.AppendArgNative(base::NumberToWString(kTestSingleDllLoad));
    child_cmd.AppendArgNative(renamed_dll);
    LaunchChildAndWait(child_cmd, &exit_code);
  };

  // 1) Nothing blocked yet: the load succeeds.
  load_renamed_dll();
  ASSERT_EQ(kDllLoadSuccess, exit_code);

  // 2) Block by section basename (image name is expected to be empty).
  TestModuleData module_data = {};
  ASSERT_TRUE(GetTestModuleData(kChineseUnicode, temp_dir, &module_data));
  EXPECT_TRUE(module_data.image_name.empty());

  std::vector<PackedListModule> blocklist;
  blocklist.emplace_back(GeneratePackedListModule(module_data.section_basename,
                                                  module_data.timedatestamp,
                                                  module_data.imagesize));
  ASSERT_TRUE(WriteModulesToBlocklist(blocklist));

  load_renamed_dll();
  ASSERT_EQ(kDllLoadFailed, exit_code);
}
// Copies test DLL 2 (which is expected to report a non-empty image name) to a
// non-ASCII file name and checks that it is blocked both when the blocklist
// entry is keyed by the original DLL name and when it is keyed by the UTF-8
// form of the new file name.
TEST_F(ThirdPartyTest, WideCharEncodingWithExportDir) {
  const std::wstring& temp_dir = GetScopedTempDirValue();
  ASSERT_TRUE(
      MakeFileCopy(GetExeDir(), kTestDllName2, temp_dir, kChineseUnicode));
  const std::wstring renamed_dll = MakePath(temp_dir, kChineseUnicode);

  int exit_code = 0;
  // Runs the child exe in single-DLL-load mode against |renamed_dll| and
  // stores its exit code in |exit_code|.
  auto load_renamed_dll = [&]() {
    base::CommandLine child_cmd =
        base::CommandLine::FromString(kTestExeFilename);
    child_cmd.AppendArgNative(GetBlTestFilePath());
    child_cmd.AppendArgNative(base::NumberToWString(kTestSingleDllLoad));
    child_cmd.AppendArgNative(renamed_dll);
    LaunchChildAndWait(child_cmd, &exit_code);
  };

  // 1) Nothing blocked yet: the load succeeds.
  load_renamed_dll();
  ASSERT_EQ(kDllLoadSuccess, exit_code);

  TestModuleData module_data = {};
  ASSERT_TRUE(GetTestModuleData(kChineseUnicode, temp_dir, &module_data));
  EXPECT_FALSE(module_data.image_name.empty());

  // 2) Block using the original DLL name as the key.
  std::vector<PackedListModule> blocklist;
  blocklist.emplace_back(GeneratePackedListModule(
      base::WideToASCII(kTestDllName2), module_data.timedatestamp,
      module_data.imagesize));
  ASSERT_TRUE(WriteModulesToBlocklist(blocklist));
  load_renamed_dll();
  ASSERT_EQ(kDllLoadFailed, exit_code);

  // 3) Block using the UTF-8 form of the renamed file as the key.
  blocklist.clear();
  blocklist.emplace_back(GeneratePackedListModule(
      base::WideToUTF8(kChineseUnicode), module_data.timedatestamp,
      module_data.imagesize));
  ASSERT_TRUE(WriteModulesToBlocklist(blocklist));
  load_renamed_dll();
  ASSERT_EQ(kDllLoadFailed, exit_code);
}
// Copies test DLL 1 to the name |kOldBlocklistDllName| and checks that the
// child fails to load it even though this test writes no blocklist entries.
TEST_F(ThirdPartyTest, DeprecatedBlocklistSanityCheck) {
  const std::wstring& temp_dir = GetScopedTempDirValue();
  ASSERT_TRUE(
      MakeFileCopy(GetExeDir(), kTestDllName1, temp_dir, kOldBlocklistDllName));

  base::CommandLine child_cmd = base::CommandLine::FromString(kTestExeFilename);
  child_cmd.AppendArgNative(GetBlTestFilePath());
  child_cmd.AppendArgNative(base::NumberToWString(kTestSingleDllLoad));
  child_cmd.AppendArgNative(MakePath(temp_dir, kOldBlocklistDllName));

  int exit_code = 0;
  LaunchChildAndWait(child_cmd, &exit_code);
  ASSERT_EQ(kDllLoadFailed, exit_code);
}
// Cross-checks elf_sha1 against base::SHA1HashString: both hashes of a
// PackedListModule built by GeneratePackedListModule() must match the
// base:: hashes of the same inputs over the first elf_sha1::kSHA1Length bytes.
TEST_F(ThirdPartyTest, SHA1SanityCheck) {
  const std::wstring& temp_dir = GetScopedTempDirValue();
  ASSERT_TRUE(
      MakeFileCopy(GetExeDir(), kTestDllName1, temp_dir, kChineseUnicode));

  TestModuleData module_data = {};
  ASSERT_TRUE(GetTestModuleData(kChineseUnicode, temp_dir, &module_data));

  // Hashes produced by the elf_sha1 implementation.
  PackedListModule elf_sha1_generated = GeneratePackedListModule(
      base::WideToUTF8(kChineseUnicode), module_data.timedatestamp,
      module_data.imagesize);

  // Reference hashes produced by base.
  const std::string expected_basename_hash =
      base::SHA1HashString(base::WideToUTF8(kChineseUnicode));
  const std::string expected_code_id_hash = base::SHA1HashString(
      GetFingerprintString(module_data.timedatestamp, module_data.imagesize));

  const int basename_diff =
      ::memcmp(&elf_sha1_generated.basename_hash[0],
               expected_basename_hash.data(), elf_sha1::kSHA1Length);
  EXPECT_EQ(basename_diff, 0);

  const int code_id_diff =
      ::memcmp(&elf_sha1_generated.code_id_hash[0],
               expected_code_id_hash.data(), elf_sha1::kSHA1Length);
  EXPECT_EQ(code_id_diff, 0);
}
// The test is registered under a DISABLED_ name when BUILDFLAG(IS_WIN) is set.
#if BUILDFLAG(IS_WIN)
#define MAYBE_PathCaseSensitive DISABLED_PathCaseSensitive
#else
#define MAYBE_PathCaseSensitive PathCaseSensitive
#endif
// Checks that the section path reported for a DLL keeps the letter case of
// the file name on disk, and that such a DLL loads while unblocked.
TEST_F(ThirdPartyTest, MAYBE_PathCaseSensitive) {
  // Copy test DLL 1 to its mixed-case name, matching the destination name and
  // the equivalent copy in the Base test.
  ASSERT_TRUE(MakeFileCopy(GetExeDir(), kTestDllName1, GetScopedTempDirValue(),
                           kTestDllName1MixedCase));

  TestModuleData module_data = {};
  ASSERT_TRUE(GetTestModuleData(kTestDllName1MixedCase, GetScopedTempDirValue(),
                                &module_data));

  // Convert the device path to a drive-letter path. The temp directory may
  // contain non-ASCII characters and the other tests in this file treat
  // section names as UTF-8, so decode as UTF-8 rather than ASCII.
  base::FilePath drive;
  ASSERT_TRUE(base::DevicePathToDriveLetterPath(
      base::FilePath(base::UTF8ToWide(module_data.section_path)), &drive));

  // The case of every path component must have been preserved.
  EXPECT_EQ(drive.value().compare(
                MakePath(GetScopedTempDirValue(), kTestDllName1MixedCase)),
            0);

  // With nothing on the blocklist the mixed-case DLL loads successfully.
  base::CommandLine cmd_line1 = base::CommandLine::FromString(kTestExeFilename);
  cmd_line1.AppendArgNative(GetBlTestFilePath());
  cmd_line1.AppendArgNative(base::NumberToWString(kTestSingleDllLoad));
  cmd_line1.AppendArgNative(
      MakePath(GetScopedTempDirValue(), kTestDllName1MixedCase));
  int exit_code = 0;
  LaunchChildAndWait(cmd_line1, &exit_code);
  ASSERT_EQ(kDllLoadSuccess, exit_code);
}
// Exercises the status-code registry value under a redirected HKCU: reset,
// read back empty, add three codes, read them back in order, reset again.
TEST_F(ThirdPartyTest, StatusCodes) {
  // Redirect HKCU so the test does not touch the real registry hive.
  registry_util::RegistryOverrideManager override_manager;
  ASSERT_NO_FATAL_FAILURE(RegRedirect(nt::HKCU, &override_manager));

  // After a reset the stored array is empty.
  ASSERT_TRUE(ResetStatusCodesForTesting());
  std::vector<ThirdPartyStatus> code_array;
  EXPECT_TRUE(QueryStatusCodes(&code_array));
  EXPECT_EQ(0u, code_array.size());

  // Add three codes, in this order.
  const ThirdPartyStatus kCodesToAdd[] = {
      ThirdPartyStatus::kFileEmpty,
      ThirdPartyStatus::kLogsCreateMutexFailure,
      ThirdPartyStatus::kHookVirtualProtectFailure,
  };
  for (const ThirdPartyStatus code : kCodesToAdd) {
    ASSERT_NO_FATAL_FAILURE(AddStatusCodeForTesting(code));
  }

  // They are read back in insertion order.
  EXPECT_TRUE(QueryStatusCodes(&code_array));
  ASSERT_EQ(3u, code_array.size());
  EXPECT_EQ(ThirdPartyStatus::kFileEmpty, code_array[0]);
  EXPECT_EQ(ThirdPartyStatus::kLogsCreateMutexFailure, code_array[1]);
  EXPECT_EQ(ThirdPartyStatus::kHookVirtualProtectFailure, code_array[2]);

  // A second reset empties the array again.
  EXPECT_TRUE(ResetStatusCodesForTesting());
  EXPECT_TRUE(QueryStatusCodes(&code_array));
  EXPECT_EQ(0u, code_array.size());

  // Undo the nt-level override.
  ASSERT_NO_FATAL_FAILURE(CancelRegRedirect(nt::HKCU));
}
}
}