#include "services/webnn/dml/context_provider_dml.h"
#include <wrl.h>
#include "base/check_is_test.h"
#include "base/notreached.h"
#include "base/types/expected_macros.h"
#include "gpu/command_buffer/service/shared_context_state.h"
#include "gpu/config/gpu_feature_info.h"
#include "gpu/config/gpu_info.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/utils.h"
#include "services/webnn/error.h"
#include "services/webnn/public/mojom/features.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
namespace webnn::dml {
namespace {
using Microsoft::WRL::ComPtr;
// Resolves the DirectML `Adapter` used for a GPU-backed WebNN context.
//
// `dxgi_adapter` is a raw pointer and may be null.
// `gpu_feature_info` is consulted for the DISABLE_WEBNN_FOR_GPU workaround.
//
// Returns:
//  - a kNotSupportedError when the DISABLE_WEBNN_FOR_GPU workaround is
//    enabled;
//  - the result of Adapter::GetGpuInstanceForTesting() when `dxgi_adapter` is
//    null, which is only allowed in tests (enforced by CHECK_IS_TEST());
//  - otherwise the result of Adapter::GetGpuInstance() for `dxgi_adapter`.
base::expected<scoped_refptr<Adapter>, mojom::ErrorPtr> GetDmlGpuAdapter(
    IDXGIAdapter* dxgi_adapter,
    const gpu::GpuFeatureInfo& gpu_feature_info) {
  // The workaround check runs before anything else, so it also applies to the
  // null-adapter (test-only) path below.
  if (gpu_feature_info.IsWorkaroundEnabled(DISABLE_WEBNN_FOR_GPU)) {
    return base::unexpected(CreateError(mojom::Error::Code::kNotSupportedError,
                                        "WebNN is blocklisted for GPU."));
  }

  // A null adapter is only expected from tests.
  if (!dxgi_adapter) {
    CHECK_IS_TEST();
    return Adapter::GetGpuInstanceForTesting();
  }

  // `dxgi_adapter` is a raw pointer, so it is passed as-is: std::move() on a
  // pointer is just a copy and would wrongly suggest an ownership transfer.
  return Adapter::GetGpuInstance(dxgi_adapter);
}
// Constructs a ContextImplDml on top of `adapter` and returns it as a
// WebNNContextImpl held by a unique_ptr whose deleter is an
// OnTaskRunnerDeleter built from `owning_task_runner`.
//
// A CommandRecorder is created first from the adapter's command queue and DML
// device. If that fails, a kUnknownError is returned and no context is
// constructed. Every other argument is forwarded to the ContextImplDml
// constructor.
base::expected<std::unique_ptr<WebNNContextImpl, OnTaskRunnerDeleter>,
               mojom::ErrorPtr>
CreateDmlContext(scoped_refptr<Adapter> adapter,
                 mojo::PendingReceiver<mojom::WebNNContext> receiver,
                 base::WeakPtr<WebNNContextProviderImpl> context_provider,
                 mojom::CreateContextOptionsPtr options,
                 mojo::ScopedDataPipeConsumerHandle write_tensor_consumer,
                 mojo::ScopedDataPipeProducerHandle read_tensor_producer,
                 const gpu::GpuFeatureInfo& gpu_feature_info,
                 gpu::CommandBufferId command_buffer_id,
                 std::unique_ptr<ScopedSequence> sequence,
                 scoped_refptr<gpu::MemoryTracker> memory_tracker,
                 scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner,
                 gpu::SharedImageManager* shared_image_manager,
                 scoped_refptr<base::SingleThreadTaskRunner> main_task_runner) {
  // The HRESULT from a failed creation is not surfaced; callers only see a
  // generic kUnknownError.
  ASSIGN_OR_RETURN(
      auto recorder,
      CommandRecorder::Create(adapter->command_queue(), adapter->dml_device()),
      [](HRESULT) {
        return CreateError(mojom::Error::Code::kUnknownError,
                           "Failed to create a CommandRecorder.");
      });

  // `owning_task_runner` is moved into the context below, so take the
  // deleter's own reference to it beforehand.
  scoped_refptr<base::SingleThreadTaskRunner> deleter_task_runner =
      owning_task_runner;

  // A raw `new` is used because the result is immediately handed to a
  // unique_ptr with a custom deleter.
  WebNNContextImpl* raw_context = new ContextImplDml(
      std::move(adapter), std::move(receiver), std::move(context_provider),
      std::move(options), std::move(write_tensor_consumer),
      std::move(read_tensor_producer), std::move(recorder), gpu_feature_info,
      command_buffer_id, std::move(sequence), std::move(memory_tracker),
      std::move(owning_task_runner), shared_image_manager,
      std::move(main_task_runner));

  return std::unique_ptr<WebNNContextImpl, OnTaskRunnerDeleter>(
      raw_context, OnTaskRunnerDeleter(std::move(deleter_task_runner)));
}
}
// Reports whether a context requested with `options` should use the DirectML
// backend.
//
// Returns false whenever the kWebNNDirectML feature is disabled. With the
// feature enabled, GPU and NPU devices return true and the CPU device returns
// false.
bool ShouldCreateDmlContext(const mojom::CreateContextOptions& options) {
  const bool directml_enabled =
      base::FeatureList::IsEnabled(mojom::features::kWebNNDirectML);
  if (!directml_enabled) {
    return false;
  }

  // Deliberately no `default:` label, so every enumerator is listed
  // explicitly.
  switch (options.device) {
    case mojom::Device::kGpu:
    case mojom::Device::kNpu:
      return true;
    case mojom::Device::kCpu:
      return false;
  }
}
// Creates a DirectML-backed WebNN context for the device named in `options`.
//
// Steps:
//  1. Fails with kNotSupportedError unless the GPU_FEATURE_TYPE_WEBNN status
//     in `gpu_feature_info` is kGpuFeatureStatusEnabled.
//  2. Picks an adapter for the requested device:
//     - kGpu: when `shared_context_state` is non-null, the DXGI adapter is
//       taken from its D3D11 device (kNotSupportedError if there is no D3D11
//       device). When it is null, a null adapter is handed to
//       GetDmlGpuAdapter().
//     - kNpu: Adapter::GetNpuInstance() is used.
//     - kCpu: not expected to reach this function (NOTREACHED()).
//  3. On success, forwards the adapter and all remaining arguments to
//     CreateDmlContext(); otherwise returns the adapter creation error.
base::expected<std::unique_ptr<WebNNContextImpl, OnTaskRunnerDeleter>,
               mojom::ErrorPtr>
CreateContextFromOptions(
    mojom::CreateContextOptionsPtr options,
    mojo::ScopedDataPipeConsumerHandle write_tensor_consumer,
    mojo::ScopedDataPipeProducerHandle read_tensor_producer,
    const gpu::GpuFeatureInfo& gpu_feature_info,
    const gpu::GPUInfo& gpu_info,
    const gpu::SharedContextState* shared_context_state,
    mojo::PendingReceiver<mojom::WebNNContext> receiver,
    base::WeakPtr<WebNNContextProviderImpl> context_provider,
    gpu::CommandBufferId command_buffer_id,
    std::unique_ptr<ScopedSequence> sequence,
    scoped_refptr<gpu::MemoryTracker> memory_tracker,
    scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner,
    gpu::SharedImageManager* shared_image_manager,
    scoped_refptr<base::SingleThreadTaskRunner> main_task_runner) {
  DCHECK(gpu_feature_info.IsInitialized());

  const bool webnn_enabled =
      gpu_feature_info.status_values[gpu::GPU_FEATURE_TYPE_WEBNN] ==
      gpu::kGpuFeatureStatusEnabled;
  if (!webnn_enabled) {
    LOG(ERROR) << "[WebNN] is not compatible with device.";
    return base::unexpected(
        CreateError(mojom::Error::Code::kNotSupportedError,
                    "WebNN is not compatible with device."));
  }

  base::expected<scoped_refptr<Adapter>, mojom::ErrorPtr> adapter_or_error;
  switch (options->device) {
    case mojom::Device::kCpu:
      NOTREACHED();
    case mojom::Device::kGpu: {
      // Stays null when there is no shared context state to query.
      ComPtr<IDXGIAdapter> dxgi_adapter;
      if (shared_context_state != nullptr) {
        ComPtr<ID3D11Device> d3d11_device =
            shared_context_state->GetD3D11Device();
        if (!d3d11_device) {
          return base::unexpected(
              CreateError(mojom::Error::Code::kNotSupportedError,
                          "Failed to get D3D11 device."));
        }

        // Walk D3D11 device -> DXGI device -> DXGI adapter. Any failure here
        // is treated as fatal.
        ComPtr<IDXGIDevice> dxgi_device;
        CHECK_EQ(d3d11_device.As(&dxgi_device), S_OK);
        CHECK_EQ(dxgi_device->GetAdapter(&dxgi_adapter), S_OK);
        CHECK(dxgi_adapter);
      }
      adapter_or_error = GetDmlGpuAdapter(dxgi_adapter.Get(), gpu_feature_info);
      break;
    }
    case mojom::Device::kNpu:
      adapter_or_error = Adapter::GetNpuInstance(gpu_feature_info, gpu_info);
      break;
  }

  if (adapter_or_error.has_value()) {
    return CreateDmlContext(
        std::move(adapter_or_error.value()), std::move(receiver),
        std::move(context_provider), std::move(options),
        std::move(write_tensor_consumer), std::move(read_tensor_producer),
        gpu_feature_info, command_buffer_id, std::move(sequence),
        std::move(memory_tracker), std::move(owning_task_runner),
        shared_image_manager, std::move(main_task_runner));
  }

  // Propagate the adapter creation failure unchanged.
  return base::unexpected(std::move(adapter_or_error.error()));
}
}