#include "base/native_library.h"
#include <windows.h>
#include "base/files/file_util.h"
#include "base/metrics/histogram_macros.h"
#include "base/path_service.h"
#include "base/scoped_native_library.h"
#include "base/strings/strcat.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/threading/scoped_blocking_call.h"
#include "base/threading/scoped_thread_priority.h"
namespace base {
namespace {
// Outcome buckets for a native-library load, recorded through
// LogLibrarayLoadResultToUMA(). These values are reported to a UMA histogram,
// so presumably existing entries must keep their numeric values -- confirm
// against the histogram definition before renumbering anything.
enum LoadLibraryResult {
  // The initial load attempt succeeded.
  SUCCEED = 0,
  // The initial attempt failed and the follow-up attempt succeeded.
  FAIL_AND_SUCCEED = 1,
  // Both the initial attempt and the follow-up attempt failed.
  FAIL_AND_FAIL = 2,
  // No loader in this file emits these; the _OBSOLETE suffix marks them
  // as retired buckets.
  UNAVAILABLE_AND_SUCCEED_OBSOLETE = 3,
  UNAVAILABLE_AND_FAIL_OBSOLETE = 4,
  // Boundary value handed to UMA_HISTOGRAM_ENUMERATION; not a real outcome.
  END = 5
};
// Records |load_result| in the "LibraryLoader.LoadNativeLibraryWindows"
// enumeration histogram. LoadLibraryResult::END is supplied as the
// histogram's boundary value.
void LogLibrarayLoadResultToUMA(LoadLibraryResult load_result) {
  UMA_HISTOGRAM_ENUMERATION("LibraryLoader.LoadNativeLibraryWindows",
                            load_result, LoadLibraryResult::END);
}
// Maps the outcome of a follow-up load attempt to its LoadLibraryResult
// bucket. Callers in this file invoke it only after an initial attempt has
// already failed, hence the FAIL_AND_* buckets.
LoadLibraryResult GetLoadLibraryResult(bool has_load_library_succeeded) {
  if (!has_load_library_succeeded)
    return LoadLibraryResult::FAIL_AND_FAIL;
  return LoadLibraryResult::FAIL_AND_SUCCEED;
}
// Loads the DLL at |library_path| and returns its handle, or nullptr.
//
// First tries LoadLibraryExW() with LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR |
// LOAD_LIBRARY_SEARCH_DEFAULT_DIRS. If that fails, it temporarily switches the
// current directory to the library's directory, retries with LoadLibraryW(),
// then restores the previous current directory. Which path succeeded (or that
// both failed) is recorded via LogLibrarayLoadResultToUMA().
//
// If |error| is non-null, |error->code| receives GetLastError() from each
// failed attempt. Note that it is written after the first failure even when
// the fallback attempt then succeeds, so it is only meaningful when nullptr
// is returned.
NativeLibrary LoadNativeLibraryHelper(const FilePath& library_path,
                                      NativeLibraryLoadError* error) {
  ScopedBlockingCall scoped_blocking_call(FROM_HERE, BlockingType::MAY_BLOCK);
  SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY_REPEATEDLY();

  // Search the library's own directory (for its dependencies) plus the
  // default search directories.
  HMODULE module_handle = ::LoadLibraryExW(
      library_path.value().c_str(), nullptr,
      LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
  if (module_handle) {
    LogLibrarayLoadResultToUMA(LoadLibraryResult::SUCCEED);
    return module_handle;
  }
  // Read the error code immediately, before any other API call overwrites it.
  if (error)
    error->code = ::GetLastError();

  // Fallback: point the current directory at the library's directory,
  // presumably so dependencies stored next to the library can be resolved by
  // LoadLibraryW(). NOTE(review): the current directory is process-global,
  // so this briefly affects every thread in the process.
  bool restore_directory = false;
  FilePath current_directory;
  if (GetCurrentDirectory(&current_directory)) {
    const FilePath library_dir = library_path.DirName();
    if (!library_dir.empty()) {
      SetCurrentDirectory(library_dir);
      restore_directory = true;
    }
  }

  module_handle = ::LoadLibraryW(library_path.value().c_str());
  // Capture the error before SetCurrentDirectory() below can clobber it.
  if (!module_handle && error)
    error->code = ::GetLastError();

  if (restore_directory)
    SetCurrentDirectory(current_directory);

  LogLibrarayLoadResultToUMA(GetLoadLibraryResult(!!module_handle));
  return module_handle;
}
// Returns a handle to the system library at |library_path|, or nullptr.
//
// If the module is already loaded, the handle comes from GetModuleHandleExW()
// with no flags (per the Win32 docs this increments the module's reference
// count). Otherwise the module is loaded with LoadLibraryExW() restricted to
// LOAD_LIBRARY_SEARCH_SYSTEM32. Only that load path is recorded to UMA. A
// successful load there is logged as FAIL_AND_SUCCEED, meaning the
// already-loaded lookup missed and the load succeeded.
//
// If |error| is non-null and the load fails, |error->code| receives
// GetLastError() from LoadLibraryExW().
NativeLibrary LoadSystemLibraryHelper(const FilePath& library_path,
                                      NativeLibraryLoadError* error) {
  ScopedBlockingCall scoped_blocking_call(FROM_HERE, BlockingType::MAY_BLOCK);
  // Initialized up front so the out-param never holds an indeterminate value,
  // independent of what GetModuleHandleExW() writes on failure.
  NativeLibrary module = nullptr;
  const BOOL module_found =
      ::GetModuleHandleExW(0, library_path.value().c_str(), &module);
  if (!module_found) {
    module = ::LoadLibraryExW(library_path.value().c_str(), nullptr,
                              LOAD_LIBRARY_SEARCH_SYSTEM32);
    // Read the error code immediately after the failing call.
    if (!module && error)
      error->code = ::GetLastError();
    LogLibrarayLoadResultToUMA(GetLoadLibraryResult(!!module));
  }
  return module;
}
// Builds the full path of |name| inside the DIR_SYSTEM directory. If
// PathService cannot resolve DIR_SYSTEM, the untouched lookup result is
// returned instead. Callers in this file treat an empty path as "not found".
FilePath GetSystemLibraryName(FilePath::StringPieceType name) {
  FilePath system_dir;
  if (!PathService::Get(DIR_SYSTEM, &system_dir))
    return system_dir;
  return system_dir.Append(name);
}
}
// Renders the stored Win32 error code (as written by the loaders in this
// file) as an unsigned decimal string.
std::string NativeLibraryLoadError::ToString() const {
  std::string decimal_code = StringPrintf("%lu", code);
  return decimal_code;
}
// Windows implementation. The NativeLibraryOptions argument is not consulted
// here; loading is delegated entirely to LoadNativeLibraryHelper().
NativeLibrary LoadNativeLibraryWithOptions(
    const FilePath& library_path,
    const NativeLibraryOptions& /* options */,
    NativeLibraryLoadError* error) {
  NativeLibrary loaded = LoadNativeLibraryHelper(library_path, error);
  return loaded;
}
// Releases |library| with FreeLibrary(). The BOOL result is discarded.
void UnloadNativeLibrary(NativeLibrary library) {
  static_cast<void>(FreeLibrary(library));
}
// Returns the address of the exported symbol |name| in |library|, or nullptr
// if GetProcAddress() cannot find it.
//
// GetProcAddress() expects a NUL-terminated name, but a StringPiece is not
// guaranteed to be NUL-terminated (e.g. when it views part of a larger
// buffer). Passing name.data() directly could therefore read past the end of
// the view. The name is copied into a std::string so c_str() supplies the
// terminator.
void* GetFunctionPointerFromNativeLibrary(NativeLibrary library,
                                          StringPiece name) {
  const std::string terminated_name(name.data(), name.size());
  return reinterpret_cast<void*>(
      GetProcAddress(library, terminated_name.c_str()));
}
// Returns the Windows file name for library |name| by appending ".dll".
// |name| is expected to be ASCII (checked in debug builds).
std::string GetNativeLibraryName(StringPiece name) {
  DCHECK(IsStringASCII(name));
  std::string dll_name(name.data(), name.size());
  dll_name.append(".dll");
  return dll_name;
}
// On Windows a loadable module is named exactly like a native library, so
// this simply defers to GetNativeLibraryName().
std::string GetLoadableModuleName(StringPiece name) {
  std::string module_name = GetNativeLibraryName(name);
  return module_name;
}
// Loads the library |name| from the DIR_SYSTEM directory and returns its
// handle, or nullptr. If the system path cannot be built, |error->code| (when
// |error| is non-null) is set to ERROR_NOT_FOUND. Otherwise error reporting
// is handled by LoadSystemLibraryHelper().
NativeLibrary LoadSystemLibrary(FilePath::StringPieceType name,
                                NativeLibraryLoadError* error) {
  const FilePath system_library_path = GetSystemLibraryName(name);
  if (!system_library_path.empty())
    return LoadSystemLibraryHelper(system_library_path, error);

  // The system directory could not be resolved.
  if (error)
    error->code = ERROR_NOT_FOUND;
  return nullptr;
}
// Returns a handle to the system library |name| (resolved inside DIR_SYSTEM)
// after pinning it with GET_MODULE_HANDLE_EX_FLAG_PIN. If the module is not
// already loaded, it is loaded first via LoadSystemLibraryHelper().
//
// Returns nullptr on failure. If |error| is non-null, |error->code| receives:
//  - ERROR_NOT_FOUND when the system library path cannot be built,
//  - whatever LoadSystemLibraryHelper() reports when loading fails, or
//  - GetLastError() when pinning the freshly loaded module fails.
// The statement order below matters: GetLastError() is read right after the
// failing GetModuleHandleExW() call.
NativeLibrary PinSystemLibrary(FilePath::StringPieceType name,
NativeLibraryLoadError* error) {
FilePath library_path = GetSystemLibraryName(name);
if (library_path.empty()) {
if (error)
error->code = ERROR_NOT_FOUND;
return nullptr;
}
ScopedBlockingCall scoped_blocking_call(FROM_HERE, BlockingType::MAY_BLOCK);
ScopedNativeLibrary module;
// Fast path: the module is already loaded; pin it and hand its handle to the
// caller.
if (::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_PIN,
library_path.value().c_str(),
ScopedNativeLibrary::Receiver(module).get())) {
return module.release();
}
// Not loaded yet: load it. |module| now holds the handle; presumably
// ScopedNativeLibrary frees it on the failure path below -- confirm against
// scoped_native_library.h.
module = ScopedNativeLibrary(LoadSystemLibraryHelper(library_path, error));
if (!module.is_valid())
return nullptr;
// Pin the now-loaded module. The handle written into |temp| is only used for
// the pin call; the handle returned to the caller is the one from the load
// above. NOTE(review): |temp| is presumably released when it goes out of
// scope, which (per the Win32 docs) does not unload a pinned module.
ScopedNativeLibrary temp;
if (::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_PIN,
library_path.value().c_str(),
ScopedNativeLibrary::Receiver(temp).get())) {
return module.release();
}
// Pinning failed: report the error from GetModuleHandleExW() and return
// nullptr without releasing |module| to the caller.
if (error)
error->code = ::GetLastError();
return nullptr;
}
}