#include "base/win/iat_patch_function.h"
#include "base/check_op.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/notreached.h"
#include "base/win/patch_util.h"
#include "base/win/pe_image.h"
namespace base {
namespace win {
namespace {
// State shared between InterceptImportedFunction() and InterceptEnumCallback()
// while the imports of a module are being enumerated. A pointer to this struct
// is passed through the enumeration as the opaque cookie.
//
// NOTE: InterceptImportedFunction() fills this in with a positional aggregate
// initializer, so the order of the fields below must not change.
struct InterceptFunctionInformation {
// Set to true by the callback once an import with a matching name was found
// and a patch was attempted (whether or not the patch itself succeeded).
bool finished_operation;
// Name of the module the function is imported from, as given by the caller.
const char* imported_from_module;
// Name of the imported function to intercept; the callback compares it
// case-insensitively against each enumerated import name.
const char* function_name;
// Address that gets written into the matching IAT entry.
RAW_PTR_EXCLUSION void* new_function;
// Optional out-parameter (may be null): receives the value the IAT entry held
// before it was overwritten.
RAW_PTR_EXCLUSION void** old_function;
// Optional out-parameter (may be null): receives the address of the matching
// IAT thunk.
RAW_PTR_EXCLUSION IMAGE_THUNK_DATA** iat_thunk;
// Result of the patch attempt, as returned by internal::ModifyCode().
// InterceptImportedFunction() seeds it with ERROR_GEN_FAILURE, which is what
// remains if no matching import is found.
DWORD return_code;
};
// Returns the function address currently stored in the IAT entry |iat_thunk|.
//
// |iat_thunk| must not be null; a null thunk is treated as a programming error
// and hits NOTREACHED().
void* GetIATFunction(IMAGE_THUNK_DATA* iat_thunk) {
  if (!iat_thunk) {
    NOTREACHED();
  }

  // Convert the thunk's Function value directly instead of storing the thunk
  // into one union member and reading it back through another: reading a
  // union member other than the one last written is undefined behavior in
  // standard C++. The assertion guarantees that the conversion neither
  // truncates nor widens the stored address.
  static_assert(sizeof(iat_thunk->u1.Function) == sizeof(void*),
                "unknown IAT thunk format");
  return reinterpret_cast<void*>(iat_thunk->u1.Function);
}
// Callback invoked for each import while enumerating a module's import tables.
//
// |cookie| must point to an InterceptFunctionInformation describing the
// function to intercept. When |name| matches the requested function name
// (case-insensitively), the matching IAT entry |iat| is overwritten with the
// replacement function, the optional out-parameters are filled in, and the
// outcome is recorded in the request.
//
// Returns true to continue the enumeration, false to stop it once a matching
// import has been handled.
bool InterceptEnumCallback(const base::win::PEImage& image,
                           const char* module,
                           DWORD ordinal,
                           const char* name,
                           DWORD hint,
                           IMAGE_THUNK_DATA* iat,
                           void* cookie) {
  auto* const request = static_cast<InterceptFunctionInformation*>(cookie);
  if (request == nullptr) {
    NOTREACHED();
  }

  DCHECK(module);

  // |name| may be null (presumably for imports that are not bound by name);
  // such entries can never match and are skipped.
  const bool is_requested_import =
      name != nullptr && lstrcmpiA(name, request->function_name) == 0;
  if (!is_requested_import) {
    return true;
  }

  // Report the previous target and the thunk location before patching, but
  // only to callers that asked for them.
  if (request->old_function != nullptr) {
    *request->old_function = GetIATFunction(iat);
  }
  if (request->iat_thunk != nullptr) {
    *request->iat_thunk = iat;
  }

  // The replacement pointer is copied byte-for-byte over the thunk's Function
  // field, so both must have the same size.
  static_assert(sizeof(iat->u1.Function) == sizeof(request->new_function),
                "unknown IAT thunk format");
  request->return_code =
      internal::ModifyCode(&iat->u1.Function, &request->new_function,
                           sizeof(request->new_function));

  // The search is over even if the write failed; the failure is reported
  // through |return_code|.
  request->finished_operation = true;
  return false;
}
// Redirects the import |function_name| (imported from |imported_from_module|)
// in the image loaded at |module_handle| so that it points at |new_function|.
//
// |old_function| and |iat_thunk| are optional out-parameters; when non-null
// they receive the previous import target and the address of the patched IAT
// thunk respectively.
//
// The regular import table is searched first; the delay-load import table is
// only searched if no match was found there.
//
// Returns the result of the patch attempt, or ERROR_GEN_FAILURE if no
// matching import was found. Null required arguments or an image that fails
// VerifyMagic() are treated as programming errors.
DWORD InterceptImportedFunction(HMODULE module_handle,
                                const char* imported_from_module,
                                const char* function_name,
                                void* new_function,
                                void** old_function,
                                IMAGE_THUNK_DATA** iat_thunk) {
  const bool has_required_arguments = module_handle && imported_from_module &&
                                      function_name && new_function;
  if (!has_required_arguments) {
    NOTREACHED();
  }

  base::win::PEImage image(module_handle);
  if (!image.VerifyMagic()) {
    NOTREACHED();
  }

  InterceptFunctionInformation request = {
      false,                 // finished_operation
      imported_from_module,  // imported_from_module
      function_name,         // function_name
      new_function,          // new_function
      old_function,          // old_function
      iat_thunk,             // iat_thunk
      ERROR_GEN_FAILURE,     // return_code, kept if nothing matches
  };

  image.EnumAllImports(InterceptEnumCallback, &request, imported_from_module);
  if (request.finished_operation) {
    return request.return_code;
  }

  // Not found among the regular imports; try the delay-load imports.
  image.EnumAllDelayImports(InterceptEnumCallback, &request,
                            imported_from_module);
  return request.return_code;
}
// Writes |original_function| back into the IAT entry |iat_thunk|, undoing an
// interception that installed |intercept_function| there.
//
// All three arguments must be non-null, and the IAT entry must still hold
// |intercept_function|; violating either is treated as a programming error.
//
// Returns the result of internal::ModifyCode().
DWORD RestoreImportedFunction(void* intercept_function,
                              void* original_function,
                              IMAGE_THUNK_DATA* iat_thunk) {
  const bool has_all_arguments =
      intercept_function && original_function && iat_thunk;
  if (!has_all_arguments) {
    NOTREACHED();
  }

  // Refuse to restore an entry that no longer points at our interceptor.
  void* const current_target = GetIATFunction(iat_thunk);
  if (current_target != intercept_function) {
    NOTREACHED();
  }

  return internal::ModifyCode(&iat_thunk->u1.Function, &original_function,
                              sizeof(original_function));
}
}
// Defaulted constructor: it performs no patching itself. The initial member
// values come from the class declaration, which is not visible in this file.
IATPatchFunction::IATPatchFunction() = default;
// Removes the patch on destruction if one is still installed. A failure to
// unpatch is only reported through a DCHECK.
IATPatchFunction::~IATPatchFunction() {
  // Nothing to undo when no interceptor is recorded.
  if (!intercept_function_) {
    return;
  }

  const DWORD unpatch_result = Unpatch();
  DCHECK_EQ(static_cast<DWORD>(NO_ERROR), unpatch_result);
}
// Loads |module| and redirects its import |function_name| (imported from
// |imported_from_module|) to |new_function|.
//
// On success the library reference obtained here is kept in |module_handle_|;
// on failure the reference is released again. Failing to load |module| is
// treated as a programming error.
//
// Returns NO_ERROR on success, otherwise the error from PatchFromModule().
DWORD IATPatchFunction::Patch(const wchar_t* module,
                              const char* imported_from_module,
                              const char* function_name,
                              void* new_function) {
  HMODULE loaded_module = LoadLibraryW(module);
  if (loaded_module == nullptr) {
    NOTREACHED();
  }

  const DWORD patch_result = PatchFromModule(
      loaded_module, imported_from_module, function_name, new_function);
  if (patch_result != NO_ERROR) {
    // Patching failed, so there is no reason to keep the module loaded.
    FreeLibrary(loaded_module);
    return patch_result;
  }

  module_handle_ = loaded_module;
  return patch_result;
}
// Redirects the import |function_name| (imported from |imported_from_module|)
// of the already-loaded |module| to |new_function|, recording the original
// target and the IAT thunk so the patch can be undone later.
//
// Must only be called while no patch is recorded in this object.
//
// Returns NO_ERROR on success, otherwise the error reported by
// InterceptImportedFunction().
DWORD IATPatchFunction::PatchFromModule(HMODULE module,
                                        const char* imported_from_module,
                                        const char* function_name,
                                        void* new_function) {
  // Precondition: this object does not currently hold a patch.
  DCHECK_EQ(nullptr, original_function_);
  DCHECK_EQ(nullptr, iat_thunk_);
  DCHECK_EQ(nullptr, intercept_function_);
  DCHECK(module);

  const DWORD patch_result = InterceptImportedFunction(
      module, imported_from_module, function_name, new_function,
      &original_function_.AsEphemeralRawAddr(),
      &iat_thunk_.AsEphemeralRawAddr());
  if (patch_result != NO_ERROR) {
    return patch_result;
  }

  // |intercept_function_| has not been assigned yet, so given the
  // precondition above this compares the captured original against null.
  // NOTE(review): possibly meant to compare against |new_function| - confirm.
  DCHECK_NE(original_function_, intercept_function_);
  intercept_function_ = new_function;
  return patch_result;
}
// Restores the original import target and forgets the patch.
//
// The recorded state is cleared and the module reference (if any) released
// even when restoring fails, so a failed unpatch is not attempted a second
// time from the destructor.
//
// Returns the result of RestoreImportedFunction().
DWORD IATPatchFunction::Unpatch() {
  const DWORD restore_result = RestoreImportedFunction(
      intercept_function_, original_function_, iat_thunk_);
  DCHECK_EQ(static_cast<DWORD>(NO_ERROR), restore_result);

  // Drop the library reference taken by Patch(), if there is one.
  if (module_handle_ != nullptr) {
    FreeLibrary(module_handle_);
    module_handle_ = nullptr;
  }

  iat_thunk_ = nullptr;
  original_function_ = nullptr;
  intercept_function_ = nullptr;

  return restore_result;
}
// Returns the import target that was in place before the patch was applied.
// Only meaningful while patched, which is checked with a DCHECK.
void* IATPatchFunction::original_function() const {
  DCHECK(is_patched());
  void* const previous_target = original_function_;
  return previous_target;
}
}
}