#include "services/webnn/public/cpp/platform_functions_win.h"
#include <winerror.h>
#include <wrl.h>
#include "base/metrics/histogram_functions.h"
#include "base/no_destructor.h"
#include "base/scoped_generic.h"
#include "base/strings/string_util_win.h"
#include "services/webnn/public/cpp/win_app_runtime_package_info.h"
namespace webnn {
namespace {
// base::ScopedGeneric traits for a wide-character string that lives on the
// default process heap. Release is done with ::HeapFree(GetProcessHeap(), ...).
struct ScopedWcharTypeTraits {
  // A null pointer represents "owns nothing".
  static wchar_t* InvalidValue() { return nullptr; }

  static void Free(wchar_t* value) {
    if (!value) {
      return;
    }
    ::HeapFree(GetProcessHeap(), 0, value);
  }
};
// Owning handle for heap-allocated wide strings handed back through
// out-parameters (filled via ScopedWcharType::Receiver).
using ScopedWcharType = base::ScopedGeneric<wchar_t*, ScopedWcharTypeTraits>;
// Returns the install location of the package named |package_full_name|, as
// reported by GetPackagePathByFullName. Returns an empty FilePath if either
// the size query or the path query does not succeed.
base::FilePath GetPackagePath(const wchar_t* package_full_name) {
  // Pass 1: a null buffer makes the API report the required length (in
  // characters, including the null terminator) via ERROR_INSUFFICIENT_BUFFER.
  uint32_t required_length = 0;
  if (GetPackagePathByFullName(package_full_name, &required_length,
                               nullptr) != ERROR_INSUFFICIENT_BUFFER) {
    return base::FilePath();
  }

  // Pass 2: WriteInto sizes the string for |required_length| characters
  // including the terminator and hands back a writable pointer into it.
  std::wstring package_path;
  wchar_t* const write_ptr = base::WriteInto(&package_path, required_length);
  const LONG status =
      GetPackagePathByFullName(package_full_name, &required_length, write_ptr);
  return status == ERROR_SUCCESS ? base::FilePath(package_path)
                                 : base::FilePath();
}
// Thin wrapper around the Windows dynamic package dependency APIs
// (TryCreatePackageDependency, AddPackageDependency, DeletePackageDependency).
// The entry points are looked up in KernelBase.dll once, when the singleton is
// first constructed. Each failing call records its HRESULT in a sparse UMA
// histogram.
//
// The singleton is held in a base::NoDestructor and is never destroyed, which
// is why the destructor is deleted.
class PlatformFunctionsWin {
 public:
  ~PlatformFunctionsWin() = delete;
  PlatformFunctionsWin(const PlatformFunctionsWin&) = delete;
  PlatformFunctionsWin& operator=(const PlatformFunctionsWin&) = delete;

  // Returns the process-wide instance, constructing it on first use.
  // Construction CHECK-fails if KernelBase.dll is not loaded or if it does not
  // export every required function.
  static PlatformFunctionsWin* GetInstance() {
    static base::NoDestructor<PlatformFunctionsWin> instance;
    return instance.get();
  }

  // Calls TryCreatePackageDependency for |package_family_name| /
  // |min_version| with the given |lifetime_kind|. |lifetime_artifact| is
  // forwarded unchanged and may be null (as done for the process lifetime
  // kind). A null user argument is passed, along with
  // PackageDependencyProcessorArchitectures_None and
  // CreatePackageDependencyOptions_None.
  //
  // Returns the new package dependency id, or an empty string on failure.
  std::wstring TryCreatePackageDependency(
      base::wcstring_view package_family_name,
      PACKAGE_VERSION min_version,
      PackageDependencyLifetimeKind lifetime_kind,
      const wchar_t* lifetime_artifact) {
    ScopedWcharType package_dependency_id;
    HRESULT hr = try_create_package_dependency_proc_(
        /*user=*/nullptr, package_family_name.c_str(), min_version,
        PackageDependencyProcessorArchitectures_None, lifetime_kind,
        lifetime_artifact, CreatePackageDependencyOptions_None,
        ScopedWcharType::Receiver(package_dependency_id).get());
    if (FAILED(hr)) {
      base::UmaHistogramSparse(
          "WebNN.ORT.TryCreatePackageDependency.ErrorResult", hr);
      return {};
    }
    // Constructing a std::wstring from a null pointer is undefined behavior,
    // so do not rely on a success HRESULT alone to have produced an id.
    if (!package_dependency_id.is_valid()) {
      return {};
    }
    return package_dependency_id.get();
  }

  // Calls AddPackageDependency for |dependency_id| with rank 0 and
  // AddPackageDependencyOptions_PrependIfRankCollision. Returns the install
  // path of the package the dependency resolved to, or an empty path on
  // failure.
  //
  // The returned PACKAGEDEPENDENCY_CONTEXT is not retained, so this class never
  // calls RemovePackageDependency for it.
  base::FilePath AddPackageDependency(base::wcstring_view dependency_id) {
    PACKAGEDEPENDENCY_CONTEXT context{};
    ScopedWcharType package_full_name;
    HRESULT hr = add_package_dependency_proc_(
        dependency_id.c_str(), /*rank=*/0,
        AddPackageDependencyOptions_PrependIfRankCollision, &context,
        ScopedWcharType::Receiver(package_full_name).get());
    if (FAILED(hr)) {
      base::UmaHistogramSparse("WebNN.ORT.AddPackageDependency.ErrorResult",
                               hr);
      return base::FilePath();
    }
    // Do not hand a null full name to the path lookup.
    if (!package_full_name.is_valid()) {
      return base::FilePath();
    }
    return GetPackagePath(package_full_name.get());
  }

  // Calls DeletePackageDependency for |dependency_id|. Returns true on success.
  // On failure, records the HRESULT to UMA and returns false.
  bool DeletePackageDependency(base::wcstring_view dependency_id) {
    HRESULT hr = delete_package_dependency_proc_(dependency_id.c_str());
    if (FAILED(hr)) {
      base::UmaHistogramSparse("WebNN.ORT.DeletePackageDependency.ErrorResult",
                               hr);
      return false;
    }
    return true;
  }

 private:
  friend class base::NoDestructor<PlatformFunctionsWin>;

  // Function-pointer types matching the SDK declarations. The names are
  // qualified with :: because, inside this class, the unqualified names refer
  // to the member functions above.
  using TryCreatePackageDependencyProc =
      decltype(::TryCreatePackageDependency)*;
  using AddPackageDependencyProc = decltype(::AddPackageDependency)*;
  using DeletePackageDependencyProc = decltype(::DeletePackageDependency)*;

  // Resolves all three entry points from KernelBase.dll. GetModuleHandle does
  // not load the module, so KernelBase.dll must already be loaded.
  PlatformFunctionsWin() {
    HMODULE kbase = ::GetModuleHandle(L"KernelBase.dll");
    // Fail at the real cause instead of passing a null module to
    // GetProcAddress.
    CHECK(kbase);
    try_create_package_dependency_proc_ =
        reinterpret_cast<TryCreatePackageDependencyProc>(
            ::GetProcAddress(kbase, "TryCreatePackageDependency"));
    CHECK(try_create_package_dependency_proc_);
    add_package_dependency_proc_ = reinterpret_cast<AddPackageDependencyProc>(
        ::GetProcAddress(kbase, "AddPackageDependency"));
    CHECK(add_package_dependency_proc_);
    delete_package_dependency_proc_ =
        reinterpret_cast<DeletePackageDependencyProc>(
            ::GetProcAddress(kbase, "DeletePackageDependency"));
    CHECK(delete_package_dependency_proc_);
  }

  TryCreatePackageDependencyProc try_create_package_dependency_proc_ = nullptr;
  AddPackageDependencyProc add_package_dependency_proc_ = nullptr;
  DeletePackageDependencyProc delete_package_dependency_proc_ = nullptr;
};
}
// Creates a process-lifetime package dependency on |package_family_name|
// (at least |min_version|) and adds it to the process. Returns the resolved
// package's install path, or an empty path if either step fails.
base::FilePath InitializePackageDependencyForProcess(
    base::wcstring_view package_family_name,
    PACKAGE_VERSION min_version) {
  const std::wstring dependency_id =
      TryCreatePackageDependencyForProcess(package_family_name, min_version);
  return dependency_id.empty() ? base::FilePath()
                               : AddPackageDependency(dependency_id);
}
// Creates a package dependency on |package_family_name| (at least
// |min_version|) using PackageDependencyLifetimeKind_FilePath, with
// |file_path| as the lifetime artifact. Returns the dependency id, or an empty
// string on failure.
std::wstring TryCreatePackageDependencyForFilePath(
    base::wcstring_view package_family_name,
    PACKAGE_VERSION min_version,
    const base::FilePath& file_path) {
  const wchar_t* const lifetime_artifact = file_path.value().c_str();
  return PlatformFunctionsWin::GetInstance()->TryCreatePackageDependency(
      package_family_name, min_version, PackageDependencyLifetimeKind_FilePath,
      lifetime_artifact);
}
// Creates a package dependency on |package_family_name| (at least
// |min_version|) using PackageDependencyLifetimeKind_Process, which takes no
// lifetime artifact. Returns the dependency id, or an empty string on failure.
std::wstring TryCreatePackageDependencyForProcess(
    base::wcstring_view package_family_name,
    PACKAGE_VERSION min_version) {
  constexpr const wchar_t* kNoLifetimeArtifact = nullptr;
  return PlatformFunctionsWin::GetInstance()->TryCreatePackageDependency(
      package_family_name, min_version, PackageDependencyLifetimeKind_Process,
      kNoLifetimeArtifact);
}
// Forwards to the singleton's AddPackageDependency. Returns the install path of
// the resolved package, or an empty path on failure.
base::FilePath AddPackageDependency(base::wcstring_view dependency_id) {
  return PlatformFunctionsWin::GetInstance()->AddPackageDependency(
      dependency_id);
}
// Forwards to the singleton's DeletePackageDependency. Returns true on success,
// false on failure.
bool DeletePackageDependency(base::wcstring_view dependency_id) {
  return PlatformFunctionsWin::GetInstance()->DeletePackageDependency(
      dependency_id);
}
}