#include "services/webnn/host/execution_provider_initializer.h"

#include <appmodel.h>
#include <wrl.h>

#include <algorithm>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "base/command_line.h"
#include "base/compiler_specific.h"
#include "base/containers/flat_map.h"
#include "base/containers/queue.h"
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/functional/concurrent_closures.h"
#include "base/logging.h"
#include "base/metrics/histogram_functions.h"
#include "base/no_destructor.h"
#include "base/notreached.h"
#include "base/sequence_checker.h"
#include "base/strings/strcat.h"
#include "base/strings/stringprintf.h"
#include "base/task/bind_post_task.h"
#include "base/task/thread_pool.h"
#include "base/threading/scoped_blocking_call.h"
#include "base/win/core_winrt_util.h"
#include "base/win/scoped_co_mem.h"
#include "base/win/scoped_hstring.h"
#include "services/webnn/public/cpp/execution_providers_info.h"
#include "services/webnn/public/cpp/platform_functions_win.h"
#include "services/webnn/public/cpp/win_app_runtime_package_info.h"
#include "services/webnn/webnn_switches.h"
#include "third_party/windows_app_sdk_headers/src/inc/abi/winml/Microsoft.Windows.AI.MachineLearning.h"
namespace webnn {
namespace {
// Shorthand for the Windows ML ABI namespace that declares the execution
// provider catalog, execution provider and ready-result types used below.
namespace abi_winml = ::ABI::Microsoft::Windows::AI::MachineLearning;
using ::ABI::Windows::Foundation::AsyncStatus;
// The async operation returned by IExecutionProvider::EnsureReadyAsync().
// Its result is an ExecutionProviderReadyResult and its progress a double.
using EnsureReadyAsyncOp =
__FIAsyncOperationWithProgress_2_Microsoft__CWindows__CAI__CMachineLearning__CExecutionProviderReadyResult_double;
// Completion handler type registered on an EnsureReadyAsyncOp through
// put_Completed().
using EnsureReadyCompletedHandler =
__FIAsyncOperationWithProgressCompletedHandler_2_Microsoft__CWindows__CAI__CMachineLearning__CExecutionProviderReadyResult_double;
// Maps an execution provider name (as reported by
// IExecutionProvider::get_Name()) to its package info: package family name,
// package version and library path.
using EpPackageInfoMap = base::flat_map<std::string, mojom::EpPackageInfoPtr>;
// Outcome of readying one execution provider. RecordEpStatus() records it into
// the "WebNN.ORT.<ep_name>.Status" histogram.
// NOTE(review): these values are logged to UMA, so presumably they must never
// be renumbered or reused and must match the histogram's enum definition.
// Confirm against the histograms metadata.
enum class ExecutionProviderStatusUma {
// EnsureReadyAsync() finished without reaching AsyncStatus::Completed.
kUnknown = 0,
// The installed EP package is older than the minimum version listed in
// kKnownEPs for that EP.
kEpVersionTooLow = 1,
// The EP reported ExecutionProviderReadyState_NotPresent before being ensured.
kNotInstalled = 2,
// EnsureReadyAsync() completed with ExecutionProviderReadyResultState_Failure.
kEnsureReadyFailed = 3,
// EnsureReadyAsync() succeeded and the package version is acceptable.
kReadyForUse = 4,
kMaxValue = kReadyForUse,
};
// Records `status` into the per-provider UMA enumeration histogram
// "WebNN.ORT.<ep_name>.Status".
void RecordEpStatus(std::string_view ep_name,
                    ExecutionProviderStatusUma status) {
  const std::string histogram_name =
      base::StrCat({"WebNN.ORT.", ep_name, ".Status"});
  base::UmaHistogramEnumeration(histogram_name, status);
}
// Strict weak ordering on package versions. Compares lexicographically by
// (Major, Minor, Build, Revision), so the most significant component that
// differs decides the result.
bool operator<(const PACKAGE_VERSION& a, const PACKAGE_VERSION& b) {
  if (a.Major == b.Major) {
    if (a.Minor == b.Minor) {
      if (a.Build == b.Build) {
        return a.Revision < b.Revision;
      }
      return a.Build < b.Build;
    }
    return a.Minor < b.Minor;
  }
  return a.Major < b.Major;
}
// Formats `version` as "Major.Minor.Build.Revision" for log messages.
std::string VersionToString(const PACKAGE_VERSION& version) {
  // The format is passed as a string literal. base::StringPrintf() is a
  // printf-style API, and a literal is what lets the compiler type-check the
  // conversions against the arguments. A std::string_view is not a
  // `const char*` format.
  return base::StringPrintf("%u.%u.%u.%u", version.Major, version.Minor,
                            version.Build, version.Revision);
}
// Returns a deep copy of `map`. The mojom::EpPackageInfoPtr values cannot be
// copied, so each one is duplicated with Clone().
EpPackageInfoMap CloneMap(const EpPackageInfoMap& map) {
  std::vector<std::pair<std::string, mojom::EpPackageInfoPtr>> cloned_entries;
  cloned_entries.reserve(map.size());
  for (const auto& [ep_name, package_info] : map) {
    cloned_entries.emplace_back(ep_name, package_info.Clone());
  }
  // Entries are copied in `map`'s iteration order, which is already sorted
  // with unique keys, so the flat_map does not need to re-sort or dedupe.
  return EpPackageInfoMap(base::sorted_unique, std::move(cloned_entries));
}
// Returns the name reported by `provider`, converted to UTF-8.
// CHECK-fails if IExecutionProvider::get_Name() does not return S_OK.
std::string GetProviderName(abi_winml::IExecutionProvider* provider) {
  base::win::ScopedHString hstring_name(nullptr);
  const HRESULT get_name_result = provider->get_Name(
      base::win::ScopedHString::Receiver(hstring_name).get());
  CHECK_EQ(get_name_result, S_OK);
  return hstring_name.GetAsUTF8();
}
std::vector<Microsoft::WRL::ComPtr<abi_winml::IExecutionProvider>>
ActivateCatalogAndGetAvailableEps() {
HRESULT hr = S_OK;
const base::CommandLine* command_line =
base::CommandLine::ForCurrentProcess();
const bool ignore_ep_blocklist =
command_line->HasSwitch(switches::kWebNNOrtIgnoreEpBlocklist);
Microsoft::WRL::ComPtr<abi_winml::IExecutionProviderCatalogStatics>
catalog_statics;
{
base::ScopedBlockingCall scoped_blocking_call(
FROM_HERE, base::BlockingType::MAY_BLOCK);
hr = base::win::RoGetActivationFactory(
base::win::ScopedHString::Create(
RuntimeClass_Microsoft_Windows_AI_MachineLearning_ExecutionProviderCatalog)
.get(),
IID_PPV_ARGS(&catalog_statics));
}
if (FAILED(hr)) {
PLOG(WARNING) << "[WebNN] RoGetActivationFactory() failed.";
return {};
}
Microsoft::WRL::ComPtr<abi_winml::IExecutionProviderCatalog> catalog;
hr = catalog_statics->GetDefault(&catalog);
if (FAILED(hr)) {
PLOG(WARNING) << "[WebNN] catalog_statics->GetDefault() failed.";
return {};
}
base::win::ScopedCoMem<abi_winml::IExecutionProvider*> comem_providers;
uint32_t providers_count = 0;
hr = catalog->FindAllProviders(&providers_count, &comem_providers);
if (FAILED(hr)) {
PLOG(WARNING) << "[WebNN] catalog->FindAllProviders() failed.";
return {};
}
auto provider_span =
UNSAFE_BUFFERS(base::span<abi_winml::IExecutionProvider*>(
comem_providers.get(), providers_count));
std::vector<Microsoft::WRL::ComPtr<abi_winml::IExecutionProvider>> providers;
for (auto provider_ptr : provider_span) {
Microsoft::WRL::ComPtr<abi_winml::IExecutionProvider> provider;
provider.Attach(provider_ptr);
const std::string provider_name = GetProviderName(provider.Get());
const auto known_it = kKnownEPs.find(provider_name);
if (known_it == kKnownEPs.end()) {
continue;
}
if (!ignore_ep_blocklist && !known_it->second.enabled) {
continue;
}
providers.push_back(provider);
}
return providers;
}
// Interprets the finished EnsureReadyAsync() operation `ensure_op` for
// `provider`. Called from the operation's completion handler.
//
// Returns the provider name paired with its package info (package family
// name, package version, library path) when two conditions hold:
//   - the ready result is ExecutionProviderReadyResultState_Success, and
//   - the package version is at least the `min_package_version` listed for
//     the provider in `kKnownEPs`.
// Otherwise returns std::nullopt.
//
// The outcome is recorded with RecordEpStatus() in every case except
// ExecutionProviderReadyResultState_InProgress, which is treated as fatal.
std::optional<std::pair<std::string, mojom::EpPackageInfoPtr>>
QueryPackageInfoFromProvider(abi_winml::IExecutionProvider* provider,
                             EnsureReadyAsyncOp* ensure_op) {
  Microsoft::WRL::ComPtr<IAsyncInfo> async_info;
  HRESULT hr = ensure_op->QueryInterface(IID_PPV_ARGS(&async_info));
  CHECK_EQ(hr, S_OK);
  std::string ep_name = GetProviderName(provider);

  AsyncStatus status;
  hr = async_info->get_Status(&status);
  CHECK_EQ(hr, S_OK);
  if (status != AsyncStatus::Completed) {
    RecordEpStatus(ep_name, ExecutionProviderStatusUma::kUnknown);
    LOG(WARNING) << "[WebNN] EnsureReadyAsync() didn't complete for "
                 << ep_name;
    return std::nullopt;
  }

  Microsoft::WRL::ComPtr<abi_winml::IExecutionProviderReadyResult> ready_result;
  hr = ensure_op->GetResults(&ready_result);
  CHECK_EQ(hr, S_OK);
  abi_winml::ExecutionProviderReadyResultState ready_state;
  hr = ready_result->get_Status(&ready_state);
  // `ready_state` is uninitialized, so it must not be switched on unless the
  // getter succeeded.
  CHECK_EQ(hr, S_OK);
  switch (ready_state) {
    case abi_winml::ExecutionProviderReadyResultState_Success: {
      base::win::ScopedHString ep_path(nullptr);
      hr = provider->get_LibraryPath(
          base::win::ScopedHString::Receiver(ep_path).get());
      CHECK_EQ(hr, S_OK);
      CHECK(ep_path.is_valid());

      Microsoft::WRL::ComPtr<ABI::Windows::ApplicationModel::IPackageId>
          package_id;
      hr = provider->get_PackageId(&package_id);
      CHECK_EQ(hr, S_OK);
      base::win::ScopedHString family_name(nullptr);
      hr = package_id->get_FamilyName(
          base::win::ScopedHString::Receiver(family_name).get());
      CHECK_EQ(hr, S_OK);
      CHECK(family_name.is_valid());

      ABI::Windows::ApplicationModel::PackageVersion abi_package_version;
      hr = package_id->get_Version(&abi_package_version);
      CHECK_EQ(hr, S_OK);
      PACKAGE_VERSION package_version = {
          .Major = abi_package_version.Major,
          .Minor = abi_package_version.Minor,
          .Build = abi_package_version.Build,
          .Revision = abi_package_version.Revision,
      };

      // Only providers listed in `kKnownEPs` are ever ensured (see
      // ActivateCatalogAndGetAvailableEps()), so the lookup must succeed.
      const auto known_it = kKnownEPs.find(ep_name);
      CHECK(known_it != kKnownEPs.end());
      const PACKAGE_VERSION& min_package_version =
          known_it->second.min_package_version;
      if (package_version < min_package_version) {
        RecordEpStatus(ep_name, ExecutionProviderStatusUma::kEpVersionTooLow);
        LOG(WARNING) << "[WebNN] Found [" << ep_name << "] package version: "
                     << VersionToString(package_version)
                     << " is lower than the minimum required version: "
                     << VersionToString(min_package_version);
        return std::nullopt;
      }

      RecordEpStatus(ep_name, ExecutionProviderStatusUma::kReadyForUse);
      return std::make_pair(
          std::move(ep_name),
          mojom::EpPackageInfo::New(std::wstring(family_name.Get()),
                                    package_version,
                                    base::FilePath(ep_path.Get())));
    }
    case abi_winml::ExecutionProviderReadyResultState_Failure: {
      RecordEpStatus(ep_name, ExecutionProviderStatusUma::kEnsureReadyFailed);
      HRESULT extended_error;
      hr = ready_result->get_ExtendedError(&extended_error);
      CHECK_EQ(hr, S_OK);
      base::win::ScopedHString diagnostic_text(nullptr);
      hr = ready_result->get_DiagnosticText(
          base::win::ScopedHString::Receiver(diagnostic_text).get());
      CHECK_EQ(hr, S_OK);
      LOG(WARNING) << "[WebNN] [" << ep_name
                   << "] failed to get ready. Extended error: " << std::hex
                   << extended_error
                   << ", diagnostic text: " << diagnostic_text.GetAsUTF8();
      return std::nullopt;
    }
    case abi_winml::ExecutionProviderReadyResultState_InProgress: {
      LOG(FATAL)
          << "[WebNN] [" << ep_name
          << "] is still in progress after EnsureReadyAsync() completed.";
    }
  }
  // Every enumerator above returns or is fatal. A value outside the enum must
  // not fall off the end of this non-void function, which would be undefined
  // behavior.
  NOTREACHED() << "[WebNN] [" << ep_name << "] unexpected ready state "
               << static_cast<int>(ready_state);
}
// Starts IExecutionProvider::EnsureReadyAsync() on `provider` and runs
// `callback` exactly once on the calling sequence.
//
// On completion, `callback` receives the result of
// QueryPackageInfoFromProvider(). It receives std::nullopt instead if the
// operation could not be started or its completion handler could not be
// registered.
void EnsureExecutionProviderReadyAsync(
    Microsoft::WRL::ComPtr<abi_winml::IExecutionProvider> provider,
    base::OnceCallback<
        void(std::optional<std::pair<std::string, mojom::EpPackageInfoPtr>>)>
        callback) {
  Microsoft::WRL::ComPtr<EnsureReadyAsyncOp> ensure_op;
  HRESULT hr = provider->EnsureReadyAsync(&ensure_op);
  if (FAILED(hr)) {
    // `hr` is logged directly; a failing COM call need not set the last error
    // that PLOG would report.
    LOG(WARNING) << "[WebNN] EnsureReadyAsync() failed for "
                 << GetProviderName(provider.Get()) << ": "
                 << logging::SystemErrorCodeToString(hr);
    std::move(callback).Run(std::nullopt);
    return;
  }

  // Split the callback: `first` is handed to the completion handler, and
  // `second` lets us still notify the caller if registering the handler
  // fails. Otherwise the handler, and the callback it owns, would simply be
  // destroyed, and callers waiting on the result would never be notified.
  // Only one of the two is ever run.
  auto split_callback = base::SplitOnceCallback(std::move(callback));
  hr = ensure_op->put_Completed(
      Microsoft::WRL::Callback<EnsureReadyCompletedHandler>(
          [provider, callback = base::BindPostTaskToCurrentDefault(
                         std::move(split_callback.first))](
              EnsureReadyAsyncOp* ensure_op, AsyncStatus status) mutable {
            // `status` is re-read from the operation inside
            // QueryPackageInfoFromProvider().
            std::move(callback).Run(
                QueryPackageInfoFromProvider(provider.Get(), ensure_op));
            return S_OK;
          })
          .Get());
  if (FAILED(hr)) {
    LOG(WARNING) << "[WebNN] put_Completed() failed for "
                 << GetProviderName(provider.Get()) << ": "
                 << logging::SystemErrorCodeToString(hr);
    std::move(split_callback.second).Run(std::nullopt);
  }
}
// Process-wide singleton that does three things:
//   - activates the Windows ML execution provider catalog,
//   - calls EnsureReadyAsync() on the available execution providers (EPs),
//   - caches the resulting package info for callers of
//     EnsureExecutionProvidersReady().
// Every method DCHECKs that it runs on the sequence the instance was created
// on (`sequence_checker_`).
class ExecutionProviderInitializer {
public:
// Returns the lazily-created singleton. Its destructor is deleted and it is
// held in a base::NoDestructor, so it is never destroyed.
static ExecutionProviderInitializer* GetInstance() {
static base::NoDestructor<ExecutionProviderInitializer> instance;
return instance.get();
}
~ExecutionProviderInitializer() = delete;
ExecutionProviderInitializer(const ExecutionProviderInitializer&) = delete;
ExecutionProviderInitializer& operator=(const ExecutionProviderInitializer&) =
delete;
// Runs `callback` with a copy of the EP package info map. The timing depends
// on `state_`:
// - kEpCatalogNotActivated: tries to activate the catalog. On success the
//   callback is queued until OnInitialize(). On failure it runs immediately
//   with an empty map and the state is left unchanged, so the next call
//   retries activation.
// - kEpCatalogActivated: initialization is in flight, so the callback is
//   queued.
// - kEpsEnsured: runs immediately with the current map.
void EnsureExecutionProvidersReady(
base::OnceCallback<void(EpPackageInfoMap)> callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
switch (state_) {
case State::kEpCatalogNotActivated: {
if (TryActivateEPCatalog()) {
state_ = State::kEpCatalogActivated;
pending_callbacks_.push(std::move(callback));
} else {
std::move(callback).Run({});
}
return;
}
case State::kEpCatalogActivated: {
pending_callbacks_.push(std::move(callback));
return;
}
case State::kEpsEnsured: {
std::move(callback).Run(CloneMap(ep_package_info_map_));
return;
}
}
}
private:
// Lets base::NoDestructor call the private constructor.
friend class base::NoDestructor<ExecutionProviderInitializer>;
ExecutionProviderInitializer() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
// Returns false, without posting anything, if
// InitializePackageDependencyForProcess() for the Windows App Runtime package
// returns an empty result. Otherwise posts ActivateCatalogAndGetAvailableEps()
// to a MayBlock, BEST_EFFORT thread pool task, delivers its result to
// Initialize() on this sequence, and returns true.
bool TryActivateEPCatalog() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (InitializePackageDependencyForProcess(kWinAppRuntimePackageFamilyName,
kWinAppRuntimePackageMinVersion)
.empty()) {
return false;
}
// base::Unretained(this) is safe: the singleton is never destroyed.
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE,
{base::MayBlock(), base::TaskPriority::BEST_EFFORT,
base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
base::BindOnce(&ActivateCatalogAndGetAvailableEps),
base::BindOnce(&ExecutionProviderInitializer::Initialize,
base::Unretained(this)));
return true;
}
// Starts EnsureReadyAsync() for each provider according to its current ready
// state:
// - NotReady: the result gates completion. OnInitialize() runs only after
//   every such provider has reported back.
// - NotPresent: records kNotInstalled and ensures the provider without gating
//   completion. Its package info, if any, is added to the map whenever it
//   arrives, possibly after OnInitialize() has already run.
//   NOTE(review): presumably because installing a missing EP can take a long
//   time; confirm.
// - Ready: treated as fatal, since nothing is expected to have readied the EP
//   before this point.
void Initialize(
std::vector<Microsoft::WRL::ComPtr<abi_winml::IExecutionProvider>>
providers) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
base::ConcurrentClosures concurrent_closures;
for (const auto& provider : providers) {
abi_winml::ExecutionProviderReadyState ready_state;
HRESULT hr = provider->get_ReadyState(&ready_state);
CHECK_EQ(hr, S_OK);
std::string ep_name = GetProviderName(provider.Get());
switch (ready_state) {
case abi_winml::ExecutionProviderReadyState_Ready: {
LOG(FATAL)
<< "[WebNN] [" << ep_name
<< "] is already in ready state before `EnsureReadyAsync()` "
"is called.";
}
case abi_winml::ExecutionProviderReadyState_NotReady: {
// The closure from `concurrent_closures` is run after the package info
// (if any) has been stored, so OnInitialize() sees it.
EnsureExecutionProviderReadyAsync(
provider,
base::BindOnce(
[](base::OnceClosure closure,
std::optional<std::pair<
std::string, mojom::EpPackageInfoPtr>> package_info) {
if (package_info.has_value()) {
auto* instance =
ExecutionProviderInitializer::GetInstance();
instance->AddExecutionProviderPackageInfo(
std::move(*package_info));
}
std::move(closure).Run();
},
concurrent_closures.CreateClosure()));
break;
}
case abi_winml::ExecutionProviderReadyState_NotPresent: {
RecordEpStatus(ep_name, ExecutionProviderStatusUma::kNotInstalled);
// No closure is taken here: this provider does not delay OnInitialize().
EnsureExecutionProviderReadyAsync(
provider, base::BindOnce(
[](std::optional<
std::pair<std::string, mojom::EpPackageInfoPtr>>
package_info) {
if (package_info.has_value()) {
auto* instance =
ExecutionProviderInitializer::GetInstance();
instance->AddExecutionProviderPackageInfo(
std::move(*package_info));
}
}));
break;
}
}
}
// Runs OnInitialize() once every closure created above has been run.
std::move(concurrent_closures)
.Done(base::BindOnce(&ExecutionProviderInitializer::OnInitialize,
base::Unretained(this)));
}
// Marks initialization finished and drains the queued callbacks in FIFO
// order. Each callback gets its own copy of the map.
void OnInitialize() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
state_ = State::kEpsEnsured;
while (!pending_callbacks_.empty()) {
std::move(pending_callbacks_.front()).Run(CloneMap(ep_package_info_map_));
pending_callbacks_.pop();
}
}
// Stores one provider's package info. If the name is already present, the
// existing entry is kept (flat_map::insert does not overwrite).
void AddExecutionProviderPackageInfo(
std::pair<std::string, mojom::EpPackageInfoPtr> package_info) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
ep_package_info_map_.insert(std::move(package_info));
}
// Package info of every provider that became ready, keyed by provider name.
EpPackageInfoMap ep_package_info_map_;
// Callbacks waiting for initialization to finish.
base::queue<base::OnceCallback<void(EpPackageInfoMap)>> pending_callbacks_;
// Initialization progress. Only moves forward, except that a failed catalog
// activation leaves the state at kEpCatalogNotActivated.
enum class State {
kEpCatalogNotActivated,
kEpCatalogActivated,
kEpsEnsured,
};
State state_ = State::kEpCatalogNotActivated;
SEQUENCE_CHECKER(sequence_checker_);
};
}
// Public entry point. Forwards to the process-wide
// ExecutionProviderInitializer, which runs `callback` with the package info
// of the execution providers that are ready for use. Depending on the
// initializer's state, `callback` runs either synchronously or later.
void EnsureExecutionProvidersReady(
    base::OnceCallback<void(EpPackageInfoMap)> callback) {
  ExecutionProviderInitializer::GetInstance()->EnsureExecutionProvidersReady(
      std::move(callback));
}
}