#include "services/webnn/webnn_context_provider_impl.h"
#include <memory>
#include <utility>
#include "base/feature_list.h"
#include "base/metrics/histogram_functions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "gpu/command_buffer/service/scheduler.h"
#include "gpu/command_buffer/service/shared_image/shared_image_manager.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/webnn_trace.h"
#include "services/webnn/public/mojom/features.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_context_impl.h"
#if BUILDFLAG(IS_WIN)
#include <string>
#include "base/types/expected_macros.h"
#include "gpu/config/gpu_driver_bug_workaround_type.h"
#include "services/webnn/dml/context_provider_dml.h"
#include "services/webnn/ort/context_impl_ort.h"
#include "services/webnn/ort/context_provider_ort.h"
#include "services/webnn/ort/environment.h"
#include "services/webnn/ort/ort_session_options.h"
#endif
#if BUILDFLAG(IS_MAC)
#include "base/mac/mac_util.h"
#endif
#if BUILDFLAG(IS_APPLE)
#include "services/webnn/coreml/context_impl_coreml.h"
#endif
#if BUILDFLAG(WEBNN_USE_TFLITE)
#include "services/webnn/tflite/context_impl_tflite.h"
#endif
namespace webnn {
namespace {

// When enabled, CreateWebNNContext() creates two Mojo data pipes (one for
// WriteTensor, one for ReadTensor) and hands the ends to the context and to
// the client. Pipe creation failures are only logged.
BASE_FEATURE(kWebNNUseDataPipe, base::FEATURE_ENABLED_BY_DEFAULT);

// When enabled, CreateWebNNContext() chooses a thread-pool
// SingleThreadTaskRunner instead of the main thread as the owning task runner
// handed to the backends that honor it (TFLite, ONNX Runtime).
BASE_FEATURE(kWebNNAllowMultipleThreads, base::FEATURE_ENABLED_BY_DEFAULT);

// Test-only backend installed through SetBackendForTesting(). Stays null unless
// a test installs one; when non-null it short-circuits real backend selection.
WebNNContextProviderImpl::BackendForTesting* g_backend_for_testing = nullptr;

using webnn::mojom::CreateContextOptionsPtr;
using webnn::mojom::WebNNContextProvider;

// Bucket values recorded into the "WebNN.DeviceType" histogram.
// NOTE(review): presumably persisted to UMA logs; keep existing values stable.
enum class DeviceTypeUma {
  kCpu = 0,
  kGpu = 1,
  kNpu = 2,
  kMaxValue = kNpu,
};

// Maps the requested mojom device onto its UMA bucket and records it.
void RecordDeviceType(const mojom::Device device) {
  DeviceTypeUma sample;
  switch (device) {
    case mojom::Device::kCpu:
      sample = DeviceTypeUma::kCpu;
      break;
    case mojom::Device::kGpu:
      sample = DeviceTypeUma::kGpu;
      break;
    case mojom::Device::kNpu:
      sample = DeviceTypeUma::kNpu;
      break;
  }
  base::UmaHistogramEnumeration("WebNN.DeviceType", sample);
}

}  // namespace
WebNNContextProviderImpl::WebNNContextProviderImpl(
    scoped_refptr<gpu::SharedContextState> shared_context_state,
    gpu::GpuFeatureInfo gpu_feature_info,
    gpu::GPUInfo gpu_info,
    gpu::SharedImageManager* shared_image_manager,
    LoseAllContextsCallback lose_all_contexts_callback,
    scoped_refptr<base::SingleThreadTaskRunner> main_thread_task_runner,
    gpu::Scheduler* scheduler,
    int32_t client_id,
    mojo::SharedRemote<viz::mojom::GpuHost> gpu_host)
    : shared_context_state_(std::move(shared_context_state)),
      gpu_feature_info_(std::move(gpu_feature_info)),
      gpu_info_(std::move(gpu_info)),
      shared_image_manager_(shared_image_manager),
      lose_all_contexts_callback_(std::move(lose_all_contexts_callback)),
      scheduler_(scheduler),
      main_thread_task_runner_(std::move(main_thread_task_runner)),
      client_id_(client_id),
      gpu_host_(std::move(gpu_host)) {
  // Hard requirements: every context creation path dereferences these.
  CHECK_NE(scheduler_, nullptr);
  CHECK_NE(main_thread_task_runner_, nullptr);
  CHECK(gpu_host_.is_bound());
  // The provider must be constructed on the thread it calls "main".
  DCHECK(main_thread_task_runner_->BelongsToCurrentThread());

  // A shared context state is optional; only when present is its memory
  // tracker forwarded to the contexts created later.
  if (shared_context_state_) {
    memory_tracker_ = shared_context_state_->memory_tracker();
  }
}
WebNNContextProviderImpl::~WebNNContextProviderImpl() {
  // Destruction must happen on the sequence this provider is bound to, the
  // same one all of its other methods assert on.
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
std::unique_ptr<WebNNContextProviderImpl> WebNNContextProviderImpl::Create(
    scoped_refptr<gpu::SharedContextState> shared_context_state,
    gpu::GpuFeatureInfo gpu_feature_info,
    gpu::GPUInfo gpu_info,
    gpu::SharedImageManager* shared_image_manager,
    LoseAllContextsCallback lose_all_contexts_callback,
    scoped_refptr<base::SingleThreadTaskRunner> main_thread_task_runner,
    gpu::Scheduler* scheduler,
    int32_t client_id,
    mojo::SharedRemote<viz::mojom::GpuHost> gpu_host) {
  // Factory forwarding every argument to the constructor. `new` + WrapUnique is
  // used instead of std::make_unique, presumably because the constructor is not
  // public (NOTE(review): confirm in the header).
  auto provider = base::WrapUnique(new WebNNContextProviderImpl(
      std::move(shared_context_state), std::move(gpu_feature_info),
      std::move(gpu_info), shared_image_manager,
      std::move(lose_all_contexts_callback), std::move(main_thread_task_runner),
      scheduler, client_id, std::move(gpu_host)));
  return provider;
}
void WebNNContextProviderImpl::BindWebNNContextProvider(
    mojo::PendingReceiver<mojom::WebNNContextProvider> receiver) {
  // `this` serves every pipe added to `provider_receivers_`, so several
  // clients can be bound to the same provider instance.
  auto pending = std::move(receiver);
  provider_receivers_.Add(this, std::move(pending));
}
void WebNNContextProviderImpl::RemoveWebNNContextImpl(
    const blink::WebNNContextToken& handle) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Removing a context that was never registered is a programming error.
  const auto entry = context_impls_.find(handle);
  CHECK(entry != context_impls_.end());
  // Erasing drops the provider's owning reference to the context.
  context_impls_.erase(entry);
}
#if BUILDFLAG(IS_WIN)
void WebNNContextProviderImpl::DestroyAllContextsAndKillGpuProcess(
    const std::string& reason) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Notify every live context of the loss `reason` first.
  // NOTE(review): this assumes OnLost() does not synchronously remove entries
  // from `context_impls_`; otherwise this iteration would be invalidated.
  for (const auto& context : context_impls_) {
    context->OnLost(reason);
  }
  // One-shot callback: it is consumed here, so a second call to this method
  // would Run() an already-moved-from callback.
  std::move(lose_all_contexts_callback_).Run();
}
#endif
void WebNNContextProviderImpl::SetBackendForTesting(
    BackendForTesting* backend_for_testing) {
  // Installs a process-wide override consulted by CreateWebNNContext().
  // Passing nullptr restores normal backend selection.
  g_backend_for_testing = backend_for_testing;
}
// Mojo entry point: creates a WebNN context for `options` and replies through
// `callback` with either the new context (remote, properties, token and the
// client ends of the tensor data pipes) or an error.
//
// Backend selection depends on build flags and features:
//   - a test backend installed via SetBackendForTesting() takes precedence and
//     receives `callback` directly;
//   - Windows: ONNX Runtime (replies asynchronously, after the GPU host has
//     ensured the execution providers are ready) or DirectML;
//   - Apple: Core ML when kWebNNCoreML is enabled, macOS >= 14.4 and, on macOS,
//     the CPU is Arm;
//   - otherwise TFLite, when built with WEBNN_USE_TFLITE.
// If no backend produced a context, OnCreateWebNNContextImpl() replies with
// NotSupportedError.
void WebNNContextProviderImpl::CreateWebNNContext(
CreateContextOptionsPtr options,
WebNNContextProvider::CreateWebNNContextCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// Process-wide counter: each call gets a fresh route id, which is combined
// with `client_id_` into this context's command buffer id.
static base::AtomicSequenceNumber g_next_route_id;
const gpu::CommandBufferId command_buffer_id =
gpu::CommandBufferIdFromChannelAndRoute(client_id_,
g_next_route_id.GetNext());
// Scheduler sequence initially bound to the main thread. Paths that create
// the context on another task runner replace it (see CreateTFLiteContext()).
auto sequence = std::make_unique<ScopedSequence>(
*scheduler_, main_thread_task_runner_, command_buffer_id);
ScopedTrace scoped_trace("WebNNContextProviderImpl::CreateWebNNContext");
// Test override: hand everything, including `callback`, to the test backend
// and register whatever it returns.
if (g_backend_for_testing) {
context_impls_.emplace(g_backend_for_testing->CreateWebNNContext(
AsWeakPtr(), std::move(options), command_buffer_id, std::move(sequence),
memory_tracker_, main_thread_task_runner_, shared_image_manager_,
main_thread_task_runner_, std::move(callback)));
return;
}
// Task runner that the TFLite / ONNX Runtime paths create and own the
// context on. DirectML and Core ML below always use the main thread.
// NOTE(review): with kWebNNAllowMultipleThreads this runner is created even
// when one of those main-thread backends ends up being used.
scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner =
main_thread_task_runner_;
if (base::FeatureList::IsEnabled(kWebNNAllowMultipleThreads)) {
owning_task_runner = base::ThreadPool::CreateSingleThreadTaskRunner(
{base::MayBlock(), base::TaskPriority::USER_VISIBLE,
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN});
}
// Stays null unless a synchronous backend below fills it in.
WebNNContextImplPtr context_impl(nullptr,
OnTaskRunnerDeleter(owning_task_runner));
// `receiver` goes to the backend context; `remote` is returned to the client.
mojo::PendingRemote<mojom::WebNNContext> remote;
auto receiver = remote.InitWithNewPipeAndPassReceiver();
RecordDeviceType(options->device);
// Optional data pipes for tensor transfer. The context keeps the write
// consumer and read producer; the client gets the write producer and read
// consumer. Creation failures are only logged and context creation proceeds.
mojo::ScopedDataPipeProducerHandle write_tensor_producer;
mojo::ScopedDataPipeConsumerHandle write_tensor_consumer;
mojo::ScopedDataPipeProducerHandle read_tensor_producer;
mojo::ScopedDataPipeConsumerHandle read_tensor_consumer;
if (base::FeatureList::IsEnabled(kWebNNUseDataPipe)) {
constexpr base::ByteCount kDataPipeSize = base::MiB(16);
MojoResult result = mojo::CreateDataPipe(
kDataPipeSize.InBytes(), write_tensor_producer, write_tensor_consumer);
if (result != MOJO_RESULT_OK) {
LOG(WARNING) << "Failed to create a mojo data pipe for WriteTensor.";
}
result = mojo::CreateDataPipe(kDataPipeSize.InBytes(), read_tensor_producer,
read_tensor_consumer);
if (result != MOJO_RESULT_OK) {
LOG(WARNING) << "Failed to create a mojo data pipe for ReadTensor.";
}
}
#if BUILDFLAG(IS_WIN)
// ONNX Runtime: ask the GPU host to prepare the execution providers, then
// continue asynchronously in DidEnsureWebNNExecutionProvidersReady(), which
// owns every remaining argument, including the reply.
if (ort::ShouldCreateOrtContext(*options)) {
scoped_trace.AddStep("EnsureWebNNExecutionProvidersReady");
gpu_host_->EnsureWebNNExecutionProvidersReady(base::BindOnce(
&WebNNContextProviderImpl::DidEnsureWebNNExecutionProvidersReady,
AsWeakPtr(), std::move(scoped_trace), std::move(options),
std::move(write_tensor_producer), std::move(write_tensor_consumer),
std::move(read_tensor_producer), std::move(read_tensor_consumer),
command_buffer_id, std::move(sequence), std::move(owning_task_runner),
std::move(receiver), std::move(remote), std::move(callback)));
return;
} else if (dml::ShouldCreateDmlContext(*options)) {
// DirectML is created synchronously on the main thread. A creation error is
// replied immediately instead of falling back to another backend.
base::expected<WebNNContextImplPtr, mojom::ErrorPtr>
context_creation_results = dml::CreateContextFromOptions(
std::move(options), std::move(write_tensor_consumer),
std::move(read_tensor_producer), gpu_feature_info_, gpu_info_,
shared_context_state_.get(), std::move(receiver), AsWeakPtr(),
command_buffer_id, std::move(sequence), memory_tracker_,
main_thread_task_runner_, shared_image_manager_,
main_thread_task_runner_);
if (!context_creation_results.has_value()) {
std::move(callback).Run(mojom::CreateContextResult::NewError(
std::move(context_creation_results.error())));
return;
}
context_impl = std::move(context_creation_results.value());
}
#endif
#if BUILDFLAG(IS_APPLE)
// Core ML is created synchronously on the main thread. It is not given the
// data pipes, so all four ends are closed and the client receives none.
if (__builtin_available(macOS 14.4, *)) {
if (base::FeatureList::IsEnabled(mojom::features::kWebNNCoreML)
#if BUILDFLAG(IS_MAC)
&& base::mac::GetCPUType() == base::mac::CPUType::kArm
#endif
) {
write_tensor_producer.reset();
write_tensor_consumer.reset();
read_tensor_producer.reset();
read_tensor_consumer.reset();
context_impl = coreml::ContextImplCoreml::Create(
std::move(receiver), AsWeakPtr(), std::move(options),
command_buffer_id, std::move(sequence), memory_tracker_,
main_thread_task_runner_, shared_image_manager_,
main_thread_task_runner_);
}
}
#endif
#if BUILDFLAG(WEBNN_USE_TFLITE)
// Fallback: no platform backend was chosen, so create a TFLite context (on
// `owning_task_runner`, possibly asynchronously); it sends the reply itself.
if (!context_impl) {
CreateTFLiteContext(
std::move(scoped_trace), std::move(options),
std::move(write_tensor_producer), std::move(write_tensor_consumer),
std::move(read_tensor_producer), std::move(read_tensor_consumer),
command_buffer_id, std::move(sequence), std::move(owning_task_runner),
std::move(receiver), std::move(remote), std::move(callback));
return;
}
#endif
// Register the context and reply. A null `context_impl` is reported as
// NotSupportedError.
OnCreateWebNNContextImpl(
std::move(callback), std::move(remote), std::move(write_tensor_producer),
std::move(read_tensor_consumer), std::move(context_impl));
}
// Final step of context creation, shared by every backend. Runs on the
// provider's sequence, either directly or as the reply of a cross-thread
// creation task.
//
// If `context_impl` is null, replies with NotSupportedError. Otherwise the
// provider takes ownership of the context and replies with the client-side
// `remote`, the context's properties and token, and the client ends of the
// tensor data pipes (which may be invalid if pipes were not created).
void WebNNContextProviderImpl::OnCreateWebNNContextImpl(
    WebNNContextProvider::CreateWebNNContextCallback callback,
    mojo::PendingRemote<::webnn::mojom::WebNNContext> remote,
    mojo::ScopedDataPipeProducerHandle write_tensor_producer,
    mojo::ScopedDataPipeConsumerHandle read_tensor_consumer,
    WebNNContextImplPtr context_impl) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!context_impl) {
    std::move(callback).Run(ToError<mojom::CreateContextResult>(
        mojom::Error::Code::kNotSupportedError,
        "WebNN is not supported on this platform."));
    LOG(ERROR) << "WebNN is not supported on this platform.";
    return;
  }

  // Copy everything the reply needs out of the context *before* handing it
  // to `context_impls_`. Holding a reference into the context across the
  // ownership transfer would tie this code to the object outliving the
  // insertion (e.g. it would dangle if the insertion were rejected).
  ContextProperties context_properties = context_impl->properties();
  blink::WebNNContextToken context_handle = context_impl->handle();
  context_impls_.emplace(std::move(context_impl));

  auto success = mojom::CreateContextSuccess::New(
      std::move(remote), std::move(context_properties),
      std::move(context_handle), std::move(write_tensor_producer),
      std::move(read_tensor_consumer));
  std::move(callback).Run(
      mojom::CreateContextResult::NewSuccess(std::move(success)));
}
base::optional_ref<WebNNContextImpl>
WebNNContextProviderImpl::GetWebNNContextImplForTesting(
    const blink::WebNNContextToken& handle) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Known token: expose the owned context without transferring ownership.
  if (const auto found = context_impls_.find(handle);
      found != context_impls_.end()) {
    return found->get();
  }
  // Unknown token: flag the current Mojo message as bad and return nothing.
  mojo::ReportBadMessage(kBadMessageInvalidContext);
  return std::nullopt;
}
#if BUILDFLAG(WEBNN_USE_TFLITE)
// Creates a TFLite-backed context that is owned by `task_runner`, then
// registers it and replies via OnCreateWebNNContextImpl().
//
// If `task_runner` is not the current thread, the context is built there with
// PostTaskAndReplyWithResult(), and the reply comes back to this provider
// through a weak pointer. Otherwise it is built synchronously. The context
// receives the write-pipe consumer and read-pipe producer. The client ends
// are forwarded with the reply.
void WebNNContextProviderImpl::CreateTFLiteContext(
    ScopedTrace scoped_trace,
    mojom::CreateContextOptionsPtr options,
    mojo::ScopedDataPipeProducerHandle write_tensor_producer,
    mojo::ScopedDataPipeConsumerHandle write_tensor_consumer,
    mojo::ScopedDataPipeProducerHandle read_tensor_producer,
    mojo::ScopedDataPipeConsumerHandle read_tensor_consumer,
    gpu::CommandBufferId command_buffer_id,
    std::unique_ptr<ScopedSequence> sequence,
    scoped_refptr<base::SingleThreadTaskRunner> task_runner,
    mojo::PendingReceiver<mojom::WebNNContext> receiver,
    mojo::PendingRemote<mojom::WebNNContext> remote,
    CreateWebNNContextCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (!task_runner->BelongsToCurrentThread()) {
    // The context will live on `task_runner`, so release the caller's
    // scheduler sequence first, then create one bound to that runner.
    sequence.reset();
    sequence = std::make_unique<ScopedSequence>(*scheduler_, task_runner,
                                                command_buffer_id);
    scoped_trace.AddStep("Create on sequence");
    // `memory_tracker_` is bound by copy. Moving it would strip the provider
    // of its tracker, so every context created afterwards would get null.
    task_runner->PostTaskAndReplyWithResult(
        FROM_HERE,
        base::BindOnce(&tflite::ContextImplTflite::Create, std::move(receiver),
                       AsWeakPtr(), std::move(options),
                       std::move(write_tensor_consumer),
                       std::move(read_tensor_producer), command_buffer_id,
                       std::move(sequence), memory_tracker_, task_runner,
                       base::Unretained(shared_image_manager_.get()),
                       main_thread_task_runner_, std::move(scoped_trace)),
        base::BindOnce(&WebNNContextProviderImpl::OnCreateWebNNContextImpl,
                       AsWeakPtr(), std::move(callback), std::move(remote),
                       std::move(write_tensor_producer),
                       std::move(read_tensor_consumer)));
    return;
  }

  // Already on the owning thread: create synchronously and reply.
  WebNNContextImplPtr context_impl = tflite::ContextImplTflite::Create(
      std::move(receiver), AsWeakPtr(), std::move(options),
      std::move(write_tensor_consumer), std::move(read_tensor_producer),
      command_buffer_id, std::move(sequence), memory_tracker_,
      std::move(task_runner), shared_image_manager_, main_thread_task_runner_,
      std::move(scoped_trace));
  OnCreateWebNNContextImpl(
      std::move(callback), std::move(remote), std::move(write_tensor_producer),
      std::move(read_tensor_consumer), std::move(context_impl));
}
#endif
#if BUILDFLAG(IS_WIN)
// Continuation of CreateWebNNContext() for the ONNX Runtime backend. It runs
// once the GPU host reports that the execution providers are ready, and
// `ep_package_info` describes those providers.
//
// If the ORT environment cannot be created, all (still untouched) arguments
// are forwarded to the TFLite backend when it is built in. Otherwise
// NotSupportedError is replied. If the environment exists, an ORT context is
// created on `task_runner`, asynchronously when that is another thread. A null
// ORT context is reported as NotSupportedError.
void WebNNContextProviderImpl::DidEnsureWebNNExecutionProvidersReady(
    ScopedTrace scoped_trace,
    mojom::CreateContextOptionsPtr options,
    mojo::ScopedDataPipeProducerHandle write_tensor_producer,
    mojo::ScopedDataPipeConsumerHandle write_tensor_consumer,
    mojo::ScopedDataPipeProducerHandle read_tensor_producer,
    mojo::ScopedDataPipeConsumerHandle read_tensor_consumer,
    gpu::CommandBufferId command_buffer_id,
    std::unique_ptr<ScopedSequence> sequence,
    scoped_refptr<base::SingleThreadTaskRunner> task_runner,
    mojo::PendingReceiver<mojom::WebNNContext> receiver,
    mojo::PendingRemote<mojom::WebNNContext> remote,
    CreateWebNNContextCallback callback,
    base::flat_map<std::string, mojom::EpPackageInfoPtr> ep_package_info) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  scoped_trace.AddStep("ort::Environment::GetInstance");
  base::expected<scoped_refptr<ort::Environment>, std::string>
      env_creation_results =
          ort::Environment::GetInstance(gpu_info_, ep_package_info);
  if (!env_creation_results.has_value()) {
    LOG(ERROR) << "[WebNN] Failed to create ONNX Runtime context: "
               << env_creation_results.error();
#if BUILDFLAG(WEBNN_USE_TFLITE)
    // Nothing has been consumed yet, so every argument can be handed to the
    // TFLite backend as a fallback; it sends the reply itself.
    CreateTFLiteContext(
        std::move(scoped_trace), std::move(options),
        std::move(write_tensor_producer), std::move(write_tensor_consumer),
        std::move(read_tensor_producer), std::move(read_tensor_consumer),
        command_buffer_id, std::move(sequence), std::move(task_runner),
        std::move(receiver), std::move(remote), std::move(callback));
#else
    // No fallback backend: a null context is replied as NotSupportedError.
    OnCreateWebNNContextImpl(
        std::move(callback), std::move(remote),
        std::move(write_tensor_producer), std::move(read_tensor_consumer),
        WebNNContextImplPtr(nullptr,
                            OnTaskRunnerDeleter(main_thread_task_runner_)));
#endif
    return;
  }

  // Honor the NPU driver workaround by redirecting NPU requests to the GPU.
  mojom::Device device_type = options->device;
  if (device_type == mojom::Device::kNpu &&
      gpu_feature_info_.IsWorkaroundEnabled(gpu::DISABLE_WEBNN_FOR_NPU)) {
    device_type = mojom::Device::kGpu;
    LOG(WARNING) << "[WebNN] [WARNING] NPU device is disabled to create "
                    "ONNX Runtime context. Falling back to GPU.";
  }
  const EpWorkarounds ep_workarounds =
      env_creation_results.value()->GetEpWorkarounds(device_type);

  if (!task_runner->BelongsToCurrentThread()) {
    // The context will live on `task_runner`, so release the caller's
    // scheduler sequence first, then create one bound to that runner.
    sequence.reset();
    sequence = std::make_unique<ScopedSequence>(*scheduler_, task_runner,
                                                command_buffer_id);
    scoped_trace.AddStep("Create on sequence");
    // `memory_tracker_` is bound by copy. Moving it would strip the provider
    // of its tracker, so every context created afterwards would get null.
    task_runner->PostTaskAndReplyWithResult(
        FROM_HERE,
        base::BindOnce(
            &ort::ContextImplOrt::Create, std::move(receiver), AsWeakPtr(),
            ep_workarounds, std::move(options), device_type,
            std::move(write_tensor_consumer), std::move(read_tensor_producer),
            std::move(env_creation_results.value()), command_buffer_id,
            std::move(sequence), memory_tracker_, task_runner,
            base::Unretained(shared_image_manager_.get()),
            main_thread_task_runner_, std::move(scoped_trace)),
        base::BindOnce(&WebNNContextProviderImpl::OnCreateWebNNContextImpl,
                       AsWeakPtr(), std::move(callback), std::move(remote),
                       std::move(write_tensor_producer),
                       std::move(read_tensor_consumer)));
    return;
  }

  WebNNContextImplPtr context_impl = ort::ContextImplOrt::Create(
      std::move(receiver), AsWeakPtr(), ep_workarounds, std::move(options),
      device_type, std::move(write_tensor_consumer),
      std::move(read_tensor_producer), std::move(env_creation_results.value()),
      command_buffer_id, std::move(sequence), memory_tracker_,
      std::move(task_runner), shared_image_manager_, main_thread_task_runner_,
      std::move(scoped_trace));
  // `options`, `receiver`, `sequence`, `task_runner`, `scoped_trace` and the
  // backend pipe ends were consumed above, so no TFLite fallback is attempted
  // here; a null result is replied as NotSupportedError.
  OnCreateWebNNContextImpl(
      std::move(callback), std::move(remote), std::move(write_tensor_producer),
      std::move(read_tensor_consumer), std::move(context_impl));
}
#endif
}