#include "chrome/common/conflicts/module_watcher_win.h"
#include <windows.h>
#include <winternl.h>
#include <tlhelp32.h>
#include <string>
#include <string_view>
#include <utility>
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/memory/ptr_util.h"
#include "base/no_destructor.h"
#include "base/strings/utf_string_conversions.h"
#include "base/synchronization/lock.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/win/scoped_handle.h"
// Values the loader passes as |notification_reason| to the callback registered
// through LdrRegisterDllNotification (handled in LoaderNotificationCallback).
// They are compared against values supplied by ntdll, so they must match the
// OS definitions exactly.
enum {
LDR_DLL_NOTIFICATION_REASON_LOADED = 1,
LDR_DLL_NOTIFICATION_REASON_UNLOADED = 2,
};
// Payload the loader supplies with a "DLL loaded" notification. Declared here
// rather than taken from an SDK header; because ntdll fills this structure in,
// the field order and types must match the layout the OS uses.
// OnModuleEvent reads FullDllName, DllBase and SizeOfImage from it.
struct LDR_DLL_LOADED_NOTIFICATION_DATA {
ULONG Flags;  // Not read in this file.
PCUNICODE_STRING FullDllName;
PCUNICODE_STRING BaseDllName;
PVOID DllBase;
ULONG SizeOfImage;
};
using PLDR_DLL_LOADED_NOTIFICATION_DATA = LDR_DLL_LOADED_NOTIFICATION_DATA*;
// Payload the loader supplies with a "DLL unloaded" notification. It has the
// same fields as the loaded variant and, like it, must match the OS layout.
// LoaderNotificationCallback currently ignores unload notifications, so
// nothing here reads it.
struct LDR_DLL_UNLOADED_NOTIFICATION_DATA {
ULONG Flags;
PCUNICODE_STRING FullDllName;
PCUNICODE_STRING BaseDllName;
PVOID DllBase;
ULONG SizeOfImage;
};
using PLDR_DLL_UNLOADED_NOTIFICATION_DATA = LDR_DLL_UNLOADED_NOTIFICATION_DATA*;
// Notification payload handed to the loader callback. Which member is active
// is determined by the accompanying notification reason: LoaderNotificationCallback
// reads |Loaded| for LDR_DLL_NOTIFICATION_REASON_LOADED.
union LDR_DLL_NOTIFICATION_DATA {
LDR_DLL_LOADED_NOTIFICATION_DATA Loaded;
LDR_DLL_UNLOADED_NOTIFICATION_DATA Unloaded;
};
using PLDR_DLL_NOTIFICATION_DATA = LDR_DLL_NOTIFICATION_DATA*;
// Signature of the callback the loader invokes for each DLL notification.
// ModuleWatcher::LoaderNotificationCallback is registered with this type.
using PLDR_DLL_NOTIFICATION_FUNCTION =
VOID(CALLBACK*)(ULONG notification_reason,
const LDR_DLL_NOTIFICATION_DATA* notification_data,
PVOID context);
// Signatures of the ntdll exports LdrRegisterDllNotification and
// LdrUnregisterDllNotification. They are resolved at runtime with
// GetProcAddress rather than linked against directly.
using LdrRegisterDllNotificationFunc =
NTSTATUS(NTAPI*)(ULONG flags,
PLDR_DLL_NOTIFICATION_FUNCTION notification_function,
PVOID context,
PVOID* cookie);
using LdrUnregisterDllNotificationFunc = NTSTATUS(NTAPI*)(PVOID cookie);
namespace {

// Process-wide lock for |g_module_watcher_instance|. Wrapped in NoDestructor so
// the lock is never torn down during process exit.
base::Lock& GetModuleWatcherLock() {
  static base::NoDestructor<base::Lock> lock_storage;
  return *lock_storage;
}

// The one live ModuleWatcher, or null when none exists. Set and cleared while
// holding GetModuleWatcherLock().
ModuleWatcher* g_module_watcher_instance = nullptr;

// Module hosting the loader-notification API, and the export names looked up
// in it at runtime.
constexpr wchar_t kNtDll[] = L"ntdll.dll";
constexpr char kLdrRegisterDllNotification[] = "LdrRegisterDllNotification";
constexpr char kLdrUnregisterDllNotification[] = "LdrUnregisterDllNotification";

// Converts a counted UNICODE_STRING into a FilePath. UNICODE_STRING::Length is
// expressed in bytes, and the buffer is consumed by explicit length rather
// than relying on a terminating NUL.
base::FilePath ToFilePath(const UNICODE_STRING* str) {
  const size_t char_count = str->Length / sizeof(wchar_t);
  const std::wstring_view path_chars(str->Buffer, char_count);
  return base::FilePath(path_chars);
}

// Translates loader notification data into a ModuleEvent of |event_type| and
// runs |callback| with it synchronously on the calling thread.
template <typename NotificationDataType>
void OnModuleEvent(ModuleWatcher::ModuleEventType event_type,
                   const NotificationDataType& notification_data,
                   const ModuleWatcher::OnModuleEventCallback& callback) {
  const ModuleWatcher::ModuleEvent module_event(
      event_type, ToFilePath(notification_data.FullDllName),
      notification_data.DllBase, notification_data.SizeOfImage);
  callback.Run(module_event);
}

}  // namespace
// static
// Creates the process-wide ModuleWatcher that reports module events to
// |callback|. Returns nullptr if a watcher already exists.
std::unique_ptr<ModuleWatcher> ModuleWatcher::Create(
    OnModuleEventCallback callback) {
  ModuleWatcher* watcher = nullptr;
  {
    base::AutoLock auto_lock(GetModuleWatcherLock());
    // Only one watcher may be alive at a time.
    if (g_module_watcher_instance != nullptr)
      return nullptr;
    // The constructor is not public, hence the explicit new + WrapUnique.
    watcher = new ModuleWatcher();
    g_module_watcher_instance = watcher;
  }
  // Initialization (which registers with the loader) happens after the lock
  // is released. NOTE(review): presumably this avoids lock-order problems,
  // since LoaderNotificationCallback also takes this lock; confirm.
  watcher->Initialize(std::move(callback));
  return base::WrapUnique(watcher);
}
// Stops loader notifications for this watcher, then clears the global
// instance pointer so GetCallbackForContext() no longer matches it.
ModuleWatcher::~ModuleWatcher() {
  // Unregister first, before the global pointer is cleared.
  UnregisterDllNotificationCallback();

  base::AutoLock auto_lock(GetModuleWatcherLock());
  DCHECK_EQ(g_module_watcher_instance, this);
  g_module_watcher_instance = nullptr;
}
// Members use their default initializers. Loader registration and the
// enumeration of already-loaded modules happen in Initialize(), which Create()
// calls after construction.
ModuleWatcher::ModuleWatcher() = default;
// Stores the client callback, subscribes to loader notifications, and kicks
// off a thread-pool task that reports modules loaded before the subscription.
void ModuleWatcher::Initialize(OnModuleEventCallback callback) {
  callback_ = std::move(callback);
  RegisterDllNotificationCallback();

  // Events from the enumeration are posted back to the current sequence and
  // routed through a WeakPtr, so they are dropped if this watcher has already
  // been destroyed when they arrive.
  auto reply_task_runner = base::SequencedTaskRunner::GetCurrentDefault();
  auto forward_event = base::BindRepeating(&ModuleWatcher::RunCallback,
                                           weak_ptr_factory_.GetWeakPtr());
  base::ThreadPool::PostTask(
      FROM_HERE,
      {base::TaskPriority::BEST_EFFORT,
       base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
      base::BindOnce(&ModuleWatcher::EnumerateAlreadyLoadedModules,
                     std::move(reply_task_runner), std::move(forward_event)));
}
void ModuleWatcher::RegisterDllNotificationCallback() {
LdrRegisterDllNotificationFunc reg_fn =
reinterpret_cast<LdrRegisterDllNotificationFunc>(::GetProcAddress(
::GetModuleHandle(kNtDll), kLdrRegisterDllNotification));
if (reg_fn)
reg_fn(0, &LoaderNotificationCallback, this,
&dll_notification_cookie_.AsEphemeralRawAddr());
}
// Removes the loader-notification subscription made by
// RegisterDllNotificationCallback(). Does nothing when no registration cookie
// is held (registration never happened or did not produce one).
void ModuleWatcher::UnregisterDllNotificationCallback() {
  if (!dll_notification_cookie_)
    return;

  LdrUnregisterDllNotificationFunc unreg_fn =
      reinterpret_cast<LdrUnregisterDllNotificationFunc>(::GetProcAddress(
          ::GetModuleHandle(kNtDll), kLdrUnregisterDllNotification));
  if (unreg_fn)
    unreg_fn(dll_notification_cookie_);

  // The cookie is no longer meaningful once unregistered; clear it so a
  // repeated call is a no-op.
  dll_notification_cookie_ = nullptr;
}
// static
// Runs on the thread pool (posted from Initialize()). Takes a Toolhelp snapshot
// of the current process's modules and posts one kModuleAlreadyLoaded event
// per module to |task_runner|, where |callback| is run. Returns silently if no
// snapshot can be obtained.
void ModuleWatcher::EnumerateAlreadyLoadedModules(
    scoped_refptr<base::SequencedTaskRunner> task_runner,
    OnModuleEventCallback callback) {
  // The CreateToolhelp32Snapshot documentation asks callers to retry while it
  // fails with ERROR_BAD_LENGTH; the retry count is capped so this cannot spin
  // forever. Any other error aborts the enumeration.
  constexpr int kMaxSnapshotAttempts = 5;
  const DWORD current_pid = ::GetCurrentProcessId();
  base::win::ScopedHandle snapshot;
  for (int attempt = 0; attempt < kMaxSnapshotAttempts && !snapshot.is_valid();
       ++attempt) {
    snapshot.Set(::CreateToolhelp32Snapshot(
        TH32CS_SNAPMODULE | TH32CS_SNAPMODULE32, current_pid));
    if (!snapshot.is_valid() && ::GetLastError() != ERROR_BAD_LENGTH)
      return;
  }
  if (!snapshot.is_valid())
    return;

  MODULEENTRY32 entry = {sizeof(entry)};
  BOOL has_entry = ::Module32First(snapshot.Get(), &entry);
  while (has_entry != FALSE) {
    ModuleEvent event(ModuleEventType::kModuleAlreadyLoaded,
                      base::FilePath(entry.szExePath), entry.modBaseAddr,
                      entry.modBaseSize);
    task_runner->PostTask(FROM_HERE, base::BindOnce(callback, event));
    has_entry = ::Module32Next(snapshot.Get(), &entry);
  }
}
// static
// Returns a copy of the live watcher's callback when |context| identifies the
// current instance, and a null callback otherwise. The copy is made under the
// lock so the caller can run it after the lock has been released.
ModuleWatcher::OnModuleEventCallback ModuleWatcher::GetCallbackForContext(
    void* context) {
  base::AutoLock auto_lock(GetModuleWatcherLock());
  ModuleWatcher* const live_watcher = g_module_watcher_instance;
  if (context == live_watcher)
    return live_watcher->callback_;
  return OnModuleEventCallback();
}
// static
// Entry point invoked by the loader for DLL notifications registered in
// RegisterDllNotificationCallback(); |context| is the ModuleWatcher that was
// passed at registration. Load notifications are reported to the client
// callback synchronously on the notifying thread; unload notifications are
// currently ignored.
void __stdcall ModuleWatcher::LoaderNotificationCallback(
    unsigned long notification_reason,
    const LDR_DLL_NOTIFICATION_DATA* notification_data,
    void* context) {
  // Look up the callback first: nothing is done, not even reason validation,
  // when |context| does not belong to the live watcher.
  const OnModuleEventCallback callback = GetCallbackForContext(context);
  if (callback.is_null())
    return;

  if (notification_reason == LDR_DLL_NOTIFICATION_REASON_LOADED) {
    OnModuleEvent(ModuleEventType::kModuleLoaded, notification_data->Loaded,
                  callback);
  } else if (notification_reason != LDR_DLL_NOTIFICATION_REASON_UNLOADED) {
    NOTREACHED() << "Unknown LDR_DLL_NOTIFICATION_REASON: "
                 << notification_reason;
  }
}
// Forwards |event| to the client callback. Initialize() binds this through a
// WeakPtr, so already-loaded-module events that arrive after the watcher is
// destroyed are dropped instead of reaching a dead object.
void ModuleWatcher::RunCallback(const ModuleEvent& event) {
callback_.Run(event);
}