#include "sandbox/policy/win/hook_util/hook_util.h"
#include <windows.h>
#include "chrome/chrome_elf/hook_util/test/hook_util_test_dll.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {
// Name of the DLL whose import is hooked by the test below.
constexpr char kIATTestDllName[] = "hook_util_test_dll.dll";

// Name of the imported function that the test hooks.
constexpr char kIATExportedApiFunction[] = "ExportedApi";
// Replacement installed in place of ExportedApi by the test. It deliberately
// does nothing.
void IATHookedExportedApi() {}
// Second replacement for ExportedApi, used by the competing hook in the test.
// The printf exists only to give this function a body that differs from
// IATHookedExportedApi.
// NOTE(review): presumably this keeps the two replacements from being merged
// into a single address (e.g. by identical-code folding) -- confirm.
void IATHookedExportedApiTwo() {
  printf("Something to make this function different!\n");
}
class HookTest : public testing::Test {
protected:
HookTest() = default;
};
// Walks sandbox::policy::IATHook through a hook / unhook life cycle on this
// module's import of ExportedApi from the test DLL. Each step checks the
// Windows error code returned, and ExportedApiCallCount() is used to observe
// whether calls are still being counted.
TEST_F(HookTest, IATHook) {
  // Baseline: before any hook is installed, every call is counted.
  ASSERT_EQ(0, ExportedApiCallCount());
  ExportedApi();
  ExportedApi();
  ASSERT_EQ(2, ExportedApiCallCount());

  // Arguments shared by the Hook() calls below.
  const HMODULE this_module = ::GetModuleHandle(nullptr);
  void* const no_op_target = reinterpret_cast<void*>(IATHookedExportedApi);

  // Install the hook. Nothing below is meaningful if this fails, so bail out.
  sandbox::policy::IATHook iat_hook;
  if (iat_hook.Hook(this_module, kIATTestDllName, kIATExportedApiFunction,
                    no_op_target) != NO_ERROR) {
    ADD_FAILURE();
    return;
  }

  // Hooking again through the same object is expected to be rejected with
  // ERROR_SHARING_VIOLATION.
  if (iat_hook.Hook(this_module, kIATTestDllName, kIATExportedApiFunction,
                    no_op_target) != ERROR_SHARING_VIOLATION) {
    ADD_FAILURE();
  }

  // While the hook is installed, calls must not bump the counter.
  ExportedApi();
  ExportedApi();
  ExportedApi();
  EXPECT_EQ(2, ExportedApiCallCount());

  // After a successful unhook, calls are counted again.
  if (iat_hook.Unhook() != NO_ERROR) {
    ADD_FAILURE();
  }
  ExportedApi();
  EXPECT_EQ(3, ExportedApiCallCount());

  // A second unhook on the same object is expected to fail with
  // ERROR_INVALID_PARAMETER.
  if (iat_hook.Unhook() != ERROR_INVALID_PARAMETER) {
    ADD_FAILURE();
  }

  // Hooking the function name "FooBarred" is expected to fail with
  // ERROR_PROC_NOT_FOUND.
  if (iat_hook.Hook(this_module, kIATTestDllName, "FooBarred",
                    no_op_target) != ERROR_PROC_NOT_FOUND) {
    ADD_FAILURE();
  }

  // Re-install the real hook; the competing-hook scenario below depends on it.
  if (iat_hook.Hook(this_module, kIATTestDllName, kIATExportedApiFunction,
                    no_op_target) != NO_ERROR) {
    ADD_FAILURE();
    return;
  }

  // A second, independent IATHook object hooks the same import on top of the
  // first one, using a different replacement function.
  sandbox::policy::IATHook shady_third_party_iat_hook;
  void* const competing_target =
      reinterpret_cast<void*>(IATHookedExportedApiTwo);
  if (shady_third_party_iat_hook.Hook(this_module, kIATTestDllName,
                                      kIATExportedApiFunction,
                                      competing_target) != NO_ERROR) {
    ADD_FAILURE();
  }

  // With the other hook layered on top, unhooking the first object is expected
  // to fail with ERROR_INVALID_FUNCTION.
  if (iat_hook.Unhook() != ERROR_INVALID_FUNCTION) {
    ADD_FAILURE();
  }

  // The hook installed last is expected to unhook cleanly.
  if (shady_third_party_iat_hook.Unhook() != NO_ERROR) {
    ADD_FAILURE();
  }
}
}