#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_buffers
#endif
#include "services/webnn/dml/context_impl_dml.h"
#include <limits>
#include "base/bits.h"
#include "base/check.h"
#include "base/check_is_test.h"
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/strings/strcat.h"
#include "base/types/expected_macros.h"
#include "gpu/command_buffer/service/shared_image/shared_image_manager.h"
#include "gpu/config/gpu_driver_bug_workaround_type.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/graph_impl_dml.h"
#include "services/webnn/dml/tensor_impl_dml.h"
#include "services/webnn/dml/utils.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/supported_tensors.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
namespace webnn::dml {
using Microsoft::WRL::ComPtr;
namespace {
// Test-only backend installed via ContextImplDml::SetBackendForTesting().
// When non-null, CreateGraphImpl() and CreateTensorImpl() delegate to it
// instead of creating real DirectML objects. No synchronization is performed
// on this pointer.
ContextImplDml::BackendForTesting* g_backend_for_testing = nullptr;
}  // namespace
// Builds the ContextProperties advertised by a DirectML-backed context for the
// given DirectML feature level.
//
// The baseline (DML_FEATURE_LEVEL_4_0) limits are passed positionally to the
// ContextProperties constructor. The `if (feature_level >= ...)` blocks that
// follow run in ascending feature-level order and overwrite selected limits by
// name. A later block may therefore overwrite a value set by an earlier one
// (e.g. add_input is set at 4.1 and again at 5.1).
//
// NOTE(review): the braced list of limits below is positional. Its order must
// match the parameter/field order of the data-type-limits type declared
// elsewhere (not visible here). Inserting or removing an entry silently shifts
// every following one.
ContextProperties ContextImplDml::GetProperties(
DML_FEATURE_LEVEL feature_level) {
// Feature level 4.0 is the minimum this function accepts.
CHECK_GE(feature_level, DML_FEATURE_LEVEL_4_0);
// Reusable data-type sets referenced by the limits below.
static constexpr SupportedDataTypes kFloat16To32Ints8{
OperandDataType::kFloat16, OperandDataType::kFloat32,
OperandDataType::kInt8, OperandDataType::kUint8};
static constexpr SupportedDataTypes kFloat16To32Ints32{
OperandDataType::kFloat16, OperandDataType::kFloat32,
OperandDataType::kInt32, OperandDataType::kUint32};
static constexpr SupportedDataTypes kFloat16To32Ints8To32{
OperandDataType::kFloat16, OperandDataType::kFloat32,
OperandDataType::kInt8, OperandDataType::kUint8,
OperandDataType::kInt32, OperandDataType::kUint32};
static constexpr SupportedDataTypes kFloat16To32Int8To64{
OperandDataType::kFloat16, OperandDataType::kFloat32,
OperandDataType::kInt8, OperandDataType::kInt32, OperandDataType::kInt64};
static constexpr SupportedDataTypes kInts4To32{
OperandDataType::kInt4, OperandDataType::kUint4,
OperandDataType::kInt8, OperandDataType::kUint8,
OperandDataType::kInt32, OperandDataType::kUint32};
static constexpr SupportedDataTypes kInts8To32{
OperandDataType::kInt8, OperandDataType::kUint8, OperandDataType::kInt32,
OperandDataType::kUint32};
static constexpr SupportedDataTypes kUint8To32{OperandDataType::kUint8,
OperandDataType::kUint32};
static constexpr SupportedDataTypes kGatherScatterIndicesSupportedDataTypes{
OperandDataType::kInt32, OperandDataType::kUint32,
OperandDataType::kInt64, OperandDataType::kUint64};
// Tensor byte-length limit: std::numeric_limits<uint32_t>::max()
// (2^32 - 1) bytes.
static constexpr uint64_t kTensorByteLengthLimit =
std::numeric_limits<uint32_t>::max();
// Default rank range used by most entries: ranks up to 8.
static constexpr SupportedRanks kMaxRank = SupportedRanks::UpTo(8);
ContextProperties properties(
InputOperandLayout::kNchw, Resample2DAxes::kAny,
BatchNormalizationAxis::kAny,
kTensorByteLengthLimit,
// Positional {data types, ranks} entries; see the NOTE(review) above.
{{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
{DataTypeConstraint::kInt32To64, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, {3, 5}},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)},
{DataTypeConstraint::kFloat16To32, {3, 5}},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)},
{kFloat16To32Ints32, kMaxRank},
{kInts8To32, kMaxRank},
{DataTypeConstraint::kFloat32, kMaxRank},
{kInts8To32, kMaxRank},
{kFloat16To32Ints32, kMaxRank},
{kFloat16To32Ints32, kMaxRank},
{kFloat16To32Ints32, kMaxRank},
{kFloat16To32Ints32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kUint8To32, kMaxRank},
{kUint8To32, kMaxRank},
{kUint8To32, kMaxRank},
{kUint8To32, kMaxRank},
{kUint8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
// This position takes a bare data-type set rather than a
// {data types, ranks} pair.
kUint8To32,
{DataTypeConstraint::kFloat16To32Int8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32Int8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32Int8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kGatherScatterIndicesSupportedDataTypes, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kGatherScatterIndicesSupportedDataTypes, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kGatherScatterIndicesSupportedDataTypes, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, {2, 4}},
{DataTypeConstraint::kFloat16To32, SupportedRanks::UpTo(2)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(3)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(4)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(3)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(4)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, {4, 5}},
{DataTypeConstraint::kFloat16To32, {4, 5}},
{kFloat16To32Ints8, {4, 5}},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat32, kMaxRank},
{DataTypeConstraint::kInts8, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(4)},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kGatherScatterIndicesSupportedDataTypes, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kGatherScatterIndicesSupportedDataTypes, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{DataTypeConstraint::kFloat16To32, kMaxRank},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank},
{kFloat16To32Ints32, {2, 8}},
{DataTypeConstraint::kUint8, kMaxRank},
{kFloat16To32Ints8To32, kMaxRank}});
// DML_FEATURE_LEVEL_4_1 widens the data types accepted by concat,
// element-wise add/sub/mul, comparisons, abs, sign, identity, expand,
// gather/scatter, reshape, reverse, slice, split, transpose and triangular.
if (feature_level >= DML_FEATURE_LEVEL_4_1) {
properties.data_type_limits.concat_inputs.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.add_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
properties.data_type_limits.sub_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
properties.data_type_limits.mul_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
properties.data_type_limits.equal_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.greater_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.greater_or_equal_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.lesser_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.lesser_or_equal_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.abs_input.data_types = kFloat16To32Int8To64;
properties.data_type_limits.identity_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.expand_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.gather_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.gather_elements_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.gather_nd_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.not_equal_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.reshape_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.reverse_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.scatter_elements_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.scatter_nd_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.scatter_nd_updates.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.sign_input.data_types =
DataTypeConstraint::kFloat16To32Int8To64;
properties.data_type_limits.slice_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.split_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.transpose_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.triangular_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
}
// DML_FEATURE_LEVEL_5_0 widens clamp, cumulative sum, max/min, pad, several
// reductions, where and max-pool.
if (feature_level >= DML_FEATURE_LEVEL_5_0) {
properties.data_type_limits.clamp_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.cumulative_sum_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
properties.data_type_limits.max_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.min_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.pad_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.reduce_l1_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
properties.data_type_limits.reduce_max_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.reduce_min_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.reduce_sum_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
properties.data_type_limits.reduce_sum_square_input.data_types =
DataTypeConstraint::kFloat16To32Ints32To64;
properties.data_type_limits.where_value.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.max_pool2d_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
}
// DML_FEATURE_LEVEL_5_1 extends add/sub/mul and triangular to all >=8-bit
// types, widens div/prelu/relu, and changes the resample2d rank range.
if (feature_level >= DML_FEATURE_LEVEL_5_1) {
properties.data_type_limits.add_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.sub_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.mul_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.div_input.data_types = kFloat16To32Ints8To32;
properties.data_type_limits.prelu_input.data_types =
DataTypeConstraint::kFloat16To32Int8To32;
properties.data_type_limits.relu_input.data_types =
DataTypeConstraint::kFloat16To32Int8To32;
properties.data_type_limits.resample2d_input.ranks =
SupportedRanks::UpTo(4);
properties.data_type_limits.triangular_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
}
// DML_FEATURE_LEVEL_6_0 extends div to all >=8-bit types and sets the
// dequantize scale / quantize input types to float16 and float32.
if (feature_level >= DML_FEATURE_LEVEL_6_0) {
properties.data_type_limits.div_input.data_types =
DataTypeConstraint::kAllDataTypesAtLeast8bits;
properties.data_type_limits.dequantize_linear_scale.data_types =
DataTypeConstraint::kFloat16To32;
properties.data_type_limits.quantize_linear_input.data_types =
DataTypeConstraint::kFloat16To32;
}
// DML_FEATURE_LEVEL_6_2 lets resample2d also accept int8/uint8.
if (feature_level >= DML_FEATURE_LEVEL_6_2) {
properties.data_type_limits.resample2d_input.data_types = kFloat16To32Ints8;
}
// DML_FEATURE_LEVEL_6_3 allows every data type for graph inputs and constants
// and adds 4-bit integer types for (de)quantization zero points/inputs.
if (feature_level >= DML_FEATURE_LEVEL_6_3) {
properties.data_type_limits.input.data_types = SupportedDataTypes::All();
properties.data_type_limits.constant.data_types = SupportedDataTypes::All();
properties.data_type_limits.dequantize_linear_input.data_types = kInts4To32;
properties.data_type_limits.dequantize_linear_zero_point.data_types =
kInts4To32;
properties.data_type_limits.quantize_linear_zero_point.data_types =
DataTypeConstraint::kInts4ToInts8;
}
return properties;
}
// Creates a DirectML-backed WebNN context. The advertised ContextProperties
// are derived from the adapter's maximum supported DirectML feature level.
// `command_recorder` must be non-null (enforced by the CHECK below).
// NOTE(review): `gpu_feature_info` is bound by reference into
// gpu_feature_info_; the caller presumably guarantees it outlives this
// context — confirm.
ContextImplDml::ContextImplDml(
scoped_refptr<Adapter> adapter,
mojo::PendingReceiver<mojom::WebNNContext> receiver,
base::WeakPtr<WebNNContextProviderImpl> context_provider,
mojom::CreateContextOptionsPtr options,
mojo::ScopedDataPipeConsumerHandle write_tensor_consumer,
mojo::ScopedDataPipeProducerHandle read_tensor_producer,
std::unique_ptr<CommandRecorder> command_recorder,
const gpu::GpuFeatureInfo& gpu_feature_info,
gpu::CommandBufferId command_buffer_id,
std::unique_ptr<ScopedSequence> sequence,
scoped_refptr<gpu::MemoryTracker> memory_tracker,
scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner,
gpu::SharedImageManager* shared_image_manager,
scoped_refptr<base::SingleThreadTaskRunner> main_task_runner)
// The base class is initialized before any member, so `adapter` is still
// valid here when GetProperties() reads it; it is only moved into adapter_
// afterwards.
: WebNNContextImpl(std::move(receiver),
std::move(context_provider),
GetProperties(adapter->max_supported_feature_level()),
std::move(options),
std::move(write_tensor_consumer),
std::move(read_tensor_producer),
command_buffer_id,
std::move(sequence),
std::move(memory_tracker),
std::move(owning_task_runner),
shared_image_manager,
std::move(main_task_runner)),
adapter_(std::move(adapter)),
command_recorder_(std::move(command_recorder)),
gpu_feature_info_(gpu_feature_info) {
CHECK(command_recorder_);
}
// Out-of-line defaulted destructor.
ContextImplDml::~ContextImplDml() = default;
// Returns a weak reference to this context, upcast to the WebNNContextImpl
// base. Must be called on the GPU sequence that owns the context.
base::WeakPtr<WebNNContextImpl> ContextImplDml::AsWeakPtr() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(gpu_sequence_checker_);
  auto weak_self = weak_factory_.GetWeakPtr();
  return weak_self;
}
// Installs (or, with nullptr, clears) the process-wide test backend. While
// set, CreateGraphImpl() and CreateTensorImpl() delegate to it. The pointer is
// stored without synchronization or ownership transfer.
void ContextImplDml::SetBackendForTesting(
BackendForTesting* backend_for_testing) {
g_backend_for_testing = backend_for_testing;
}
// Creates and builds a DirectML graph for `graph_info`, replying through
// `callback`. When a test backend is installed, creation is delegated to it;
// it receives only the receiver, this context, the compute resource info and
// the callback.
void ContextImplDml::CreateGraphImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
    mojom::GraphInfoPtr graph_info,
    WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
        constant_operands,
    base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
    WebNNContextImpl::CreateGraphImplCallback callback) {
  if (BackendForTesting* const test_backend = g_backend_for_testing) {
    test_backend->CreateGraphImpl(std::move(receiver), this,
                                  std::move(compute_resource_info),
                                  std::move(callback));
    return;
  }

  // Forward the DISABLE_DML_META_COMMANDS_FOR_GPU driver-bug workaround
  // state to graph construction.
  const bool disable_meta_commands = gpu_feature_info_->IsWorkaroundEnabled(
      gpu::DISABLE_DML_META_COMMANDS_FOR_GPU);

  GraphImplDml::CreateAndBuild(
      std::move(receiver), adapter_, weak_factory_.GetWeakPtr(),
      std::move(graph_info), std::move(compute_resource_info),
      std::move(constant_operands), std::move(constant_tensor_operands),
      std::move(callback), disable_meta_commands);
}
// Allocates the D3D12 buffer backing a new tensor and wraps it in a
// TensorImplDml.
//
// The buffer size is the tensor's packed byte length rounded up to a multiple
// of 4 bytes. On UMA adapters the buffer kind follows the tensor's usage:
//   - kWrite or kGraphConstant -> custom upload buffer;
//   - otherwise kRead          -> custom readback buffer;
//   - otherwise                -> default buffer.
// Non-UMA adapters always get a default buffer.
//
// Returns an error if the size cannot be rounded up without overflowing
// uint64_t or if buffer creation fails. Buffer-creation failures are also
// reported via HandleContextLostOrCrash().
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
ContextImplDml::CreateTensorImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
    mojom::TensorInfoPtr tensor_info) {
  if (g_backend_for_testing) {
    return g_backend_for_testing->CreateTensorImpl(this, std::move(receiver),
                                                   std::move(tensor_info));
  }

  constexpr uint64_t kDMLBufferAlignment = 4ull;
  const uint64_t packed_byte_length =
      static_cast<uint64_t>(tensor_info->descriptor.PackedByteLength());
  // Guard the AlignUp() below against wrapping around uint64_t.
  if (packed_byte_length >
      std::numeric_limits<uint64_t>::max() - kDMLBufferAlignment) {
    LOG(ERROR) << "[WebNN] Tensor is too large to create.";
    return base::unexpected(CreateError(mojom::Error::Code::kUnknownError,
                                        "Failed to create tensor."));
  }
  const uint64_t aligned_buffer_byte_size =
      base::bits::AlignUp(packed_byte_length, kDMLBufferAlignment);

  const bool is_uma = adapter_->IsUMA();
  const bool use_custom_upload =
      is_uma && (tensor_info->usage.Has(MLTensorUsageFlags::kWrite) ||
                 tensor_info->usage.Has(MLTensorUsageFlags::kGraphConstant));
  const bool use_custom_readback =
      is_uma && !use_custom_upload &&
      tensor_info->usage.Has(MLTensorUsageFlags::kRead);

  ComPtr<ID3D12Resource> buffer;
  HRESULT hr = S_OK;
  if (use_custom_upload) {
    hr = CreateCustomUploadBuffer(
        adapter_->d3d12_device(), aligned_buffer_byte_size,
        L"WebNN_Custom_Upload_Buffer_External", buffer);
  } else if (use_custom_readback) {
    hr = CreateCustomReadbackBuffer(
        adapter_->d3d12_device(), aligned_buffer_byte_size,
        L"WebNN_Custom_Readback_Buffer_External", buffer);
  } else {
    hr = CreateDefaultBuffer(adapter_->d3d12_device(), aligned_buffer_byte_size,
                             L"WebNN_Default_Buffer_External", buffer);
  }

  if (FAILED(hr)) {
    HandleContextLostOrCrash("Failed to create the external buffer.", hr);
    return base::unexpected(CreateError(mojom::Error::Code::kUnknownError,
                                        "Failed to create tensor."));
  }

  return base::MakeRefCounted<TensorImplDml>(std::move(receiver),
                                             std::move(buffer), AsWeakPtr(),
                                             std::move(tensor_info));
}
// Wraps a shared-image-backed D3D12 buffer as a tensor.
//
// The buffer's width must equal the tensor's packed byte length rounded up to
// a multiple of 4 bytes. Any mismatch is rejected with a kUnknownError.
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
ContextImplDml::CreateTensorFromSharedImageImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
    mojom::TensorInfoPtr tensor_info,
    WebNNTensorImpl::RepresentationPtr representation) {
  constexpr uint64_t kDMLBufferAlignment = 4ull;
  const uint64_t expected_byte_size = base::bits::AlignUp(
      static_cast<uint64_t>(tensor_info->descriptor.PackedByteLength()),
      kDMLBufferAlignment);
  const uint64_t actual_byte_size =
      representation->GetD3D12Buffer()->GetDesc().Width;

  if (actual_byte_size != expected_byte_size) {
    LOG(ERROR) << "[WebNN] Tensor size mismatched for mailbox.";
    return base::unexpected(CreateError(mojom::Error::Code::kUnknownError,
                                        "Failed to create tensor."));
  }

  return base::MakeRefCounted<TensorImplDml>(
      std::move(receiver), std::move(representation), AsWeakPtr(),
      std::move(tensor_info));
}
// Reads the full packed contents of `src_tensor` and replies via `callback`.
//
// Fast path: when the tensor is not a WebGPU-interop tensor, the adapter is
// UMA, and the GPU has completed the tensor's last submission, the tensor's
// own buffer is mapped directly.
//
// Slow path: otherwise the data is copied on the GPU into a dedicated readback
// buffer. That buffer is mapped once the command queue signals completion.
//
// Every failure replies with an error before reporting the HRESULT.
void ContextImplDml::ReadTensor(
    TensorImplDml* src_tensor,
    mojom::WebNNTensor::ReadTensorCallback callback) {
  const size_t byte_length = src_tensor->PackedByteLength();

  const bool can_map_tensor_directly =
      !src_tensor->usage().Has(MLTensorUsageFlags::kWebGpuInterop) &&
      adapter_->IsUMA() &&
      adapter_->command_queue()->GetCompletedValue() >=
          src_tensor->last_submission_fence_value();
  if (can_map_tensor_directly) {
    OnReadbackComplete(src_tensor->buffer(), byte_length, std::move(callback),
                       S_OK);
    return;
  }

  ComPtr<ID3D12Resource> readback_buffer;
  HRESULT hr = CreateReadbackBuffer(adapter_->d3d12_device(),
                                    static_cast<uint64_t>(byte_length),
                                    L"WebNN_Readback_Buffer", readback_buffer);
  if (FAILED(hr)) {
    std::move(callback).Run(ToError<mojom::ReadTensorResult>(
        mojom::Error::Code::kUnknownError, "Failed to read tensor."));
    HandleContextLostOrCrash("Failed to create the download buffer.", hr);
    return;
  }

  hr = StartRecordingIfNecessary();
  if (FAILED(hr)) {
    std::move(callback).Run(ToError<mojom::ReadTensorResult>(
        mojom::Error::Code::kUnknownError, "Failed to read tensor."));
    HandleRecordingError("Failed to start recording.", hr);
    return;
  }

  command_recorder_->ReadbackTensorWithBarrier(readback_buffer, src_tensor,
                                               byte_length);
  hr = command_recorder_->CloseAndExecute();
  if (FAILED(hr)) {
    std::move(callback).Run(ToError<mojom::ReadTensorResult>(
        mojom::Error::Code::kUnknownError, "Failed to read tensor."));
    HandleRecordingError("Failed to close and execute the command list.", hr);
    return;
  }

  adapter_->command_queue()->WaitAsync(base::BindOnce(
      &ContextImplDml::OnReadbackComplete, weak_factory_.GetWeakPtr(),
      std::move(readback_buffer), byte_length, std::move(callback)));
}
// Completes a tensor read. Maps `download_buffer`, copies the first
// `read_byte_size` bytes into a BigBuffer (or the read data pipe) and replies
// via `callback`.
//
// `hr` carries the status of the preceding GPU work. On failure, or if mapping
// fails, the callback receives an error and the HRESULT is reported.
void ContextImplDml::OnReadbackComplete(
    ComPtr<ID3D12Resource> download_buffer,
    size_t read_byte_size,
    mojom::WebNNTensor::ReadTensorCallback callback,
    HRESULT hr) {
  if (FAILED(hr)) {
    std::move(callback).Run(ToError<mojom::ReadTensorResult>(
        mojom::Error::Code::kUnknownError, "Failed to read tensor."));
    HandleRecordingError("Failed to download the buffer.", hr);
    return;
  }

  CHECK(download_buffer);
  void* mapped_download_data = nullptr;
  hr = download_buffer->Map(0, nullptr, &mapped_download_data);
  if (FAILED(hr)) {
    std::move(callback).Run(ToError<mojom::ReadTensorResult>(
        mojom::Error::Code::kUnknownError, "Failed to read tensor."));
    HandleContextLostOrCrash("Failed to map the download buffer.", hr);
    return;
  }
  // Match WriteTensor(): never build a span over a null mapping.
  CHECK(mapped_download_data);

  mojo_base::BigBuffer dst_buffer = WriteDataToDataPipeOrBigBuffer(
      UNSAFE_BUFFERS(base::span(static_cast<const uint8_t*>(mapped_download_data),
                                read_byte_size)));

  // The CPU only read through this mapping. An empty written range
  // (End <= Begin) tells D3D12 that nothing was modified, whereas nullptr
  // would declare the whole resource as CPU-written.
  constexpr D3D12_RANGE kNothingWritten = {0, 0};
  download_buffer->Unmap(0, &kNothingWritten);

  std::move(callback).Run(
      mojom::ReadTensorResult::NewBuffer(std::move(dst_buffer)));
}
// Writes `src_buffer` into `dst_tensor`.
//
// When the tensor is not a WebGPU-interop tensor, the adapter is UMA, and the
// GPU has completed the tensor's last submission, the tensor's own buffer is
// mapped and written directly. Otherwise the data is staged in a fresh upload
// buffer, and a GPU copy into the tensor is recorded and executed.
//
// Failures are reported via HandleContextLostOrCrash() or
// HandleRecordingError(); there is no reply callback.
void ContextImplDml::WriteTensor(TensorImplDml* dst_tensor,
                                 mojo_base::BigBuffer src_buffer) {
  HRESULT hr = S_OK;
  ComPtr<ID3D12Resource> buffer_to_map = dst_tensor->buffer();
  const bool is_uma_mapping_allowed =
      !dst_tensor->usage().Has(MLTensorUsageFlags::kWebGpuInterop);
  if (!is_uma_mapping_allowed || !adapter_->IsUMA() ||
      adapter_->command_queue()->GetCompletedValue() <
          dst_tensor->last_submission_fence_value()) {
    hr = CreateUploadBuffer(adapter_->d3d12_device(),
                            dst_tensor->PackedByteLength(),
                            L"WebNN_Upload_Buffer", buffer_to_map);
    if (FAILED(hr)) {
      HandleContextLostOrCrash("Failed to create the upload buffer.", hr);
      return;
    }
  }

  CHECK(buffer_to_map);
  // The CPU only writes through this mapping. An empty read range
  // (End <= Begin) tells D3D12 that no CPU reads will happen, whereas nullptr
  // would declare the whole resource as CPU-readable.
  constexpr D3D12_RANGE kNoCpuRead = {0, 0};
  void* mapped_buffer_data = nullptr;
  hr = buffer_to_map->Map(0, &kNoCpuRead, &mapped_buffer_data);
  if (FAILED(hr)) {
    HandleContextLostOrCrash("Failed to map the buffer.", hr);
    return;
  }
  CHECK(mapped_buffer_data);

  ReadDataFromBigBufferOrDataPipe(
      std::move(src_buffer),
      UNSAFE_BUFFERS(base::span(static_cast<uint8_t*>(mapped_buffer_data),
                                dst_tensor->PackedByteLength())));
  buffer_to_map->Unmap(0, nullptr);

  // If the data went into a staging buffer, copy it into the tensor on the GPU.
  if (dst_tensor->buffer() != buffer_to_map.Get()) {
    hr = StartRecordingIfNecessary();
    if (FAILED(hr)) {
      HandleRecordingError("Failed to start recording.", hr);
      return;
    }
    command_recorder_->UploadTensorWithBarrier(
        dst_tensor, std::move(buffer_to_map), dst_tensor->PackedByteLength());
    hr = command_recorder_->CloseAndExecute();
    if (FAILED(hr)) {
      HandleRecordingError("Failed to close and execute the command list.", hr);
      return;
    }
    adapter_->command_queue()->WaitAsync(base::BindOnce(
        &ContextImplDml::OnUploadComplete, weak_factory_.GetWeakPtr()));
  }
}
// Invoked once the GPU upload recorded by WriteTensor() finishes. Success
// needs no follow-up; a failure is handled as a recording error.
void ContextImplDml::OnUploadComplete(HRESULT hr) {
  if (FAILED(hr)) {
    HandleRecordingError("Failed to upload the buffer.", hr);
  }
}
// Ensures command_recorder_ exists and is open for recording.
//
// The recorder is lazily (re)created when absent; HandleRecordingError()
// resets it. Returns the failing HRESULT if creation or opening fails,
// otherwise S_OK.
HRESULT ContextImplDml::StartRecordingIfNecessary() {
  if (!command_recorder_) {
    ASSIGN_OR_RETURN(command_recorder_,
                     CommandRecorder::Create(adapter_->command_queue(),
                                             adapter_->dml_device()));
  }
  CHECK(command_recorder_);

  if (!command_recorder_->IsOpen()) {
    RETURN_IF_FAILED(command_recorder_->Open());
    CHECK(command_recorder_->IsOpen());
  }
  return S_OK;
}
// Handles a failure while recording or executing commands. It drops the
// current command recorder so StartRecordingIfNecessary() creates a fresh one
// next time, then reports the failure via HandleContextLostOrCrash().
// NOTE(review): the reset is done first, presumably because
// HandleContextLostOrCrash() may tear this context down — confirm.
void ContextImplDml::HandleRecordingError(std::string_view error_message,
HRESULT hr) {
command_recorder_.reset();
HandleContextLostOrCrash(error_message, hr);
}
void ContextImplDml::HandleContextLostOrCrash(std::string_view message_for_log,
HRESULT hr) {
LOG(ERROR) << "[WebNN] " << message_for_log << " "
<< logging::SystemErrorCodeToString(hr);
HRESULT device_removed_reason =
adapter_->d3d12_device()->GetDeviceRemovedReason();
if (FAILED(device_removed_reason)) {
LOG(ERROR) << "[WebNN] Device Removed Reason: "
<< logging::SystemErrorCodeToString(device_removed_reason);
DestroyAllContextsAndKillGpuProcess("device removed.");
return;
}
std::string_view message_for_promise;
switch (hr) {
case E_OUTOFMEMORY:
message_for_promise = "out of memory.";
break;
case DXGI_ERROR_DEVICE_RESET:
message_for_promise = "device reset.";
break;
default:
message_for_promise = "internal error.";
}
OnLost(base::StrCat({"WebNN context is lost due to ", message_for_promise}));
CHECK(hr == E_OUTOFMEMORY || hr == DXGI_ERROR_DEVICE_RESET);
}
// Returns the adapter's command queue (non-owning pointer).
CommandQueue* ContextImplDml::GetCommandQueue() const {
return adapter_->command_queue();
}
// Test-only: forces removal of the adapter's D3D12 device by querying
// ID3D12Device5 and calling RemoveDevice(). This lets tests exercise
// device-removed handling. CHECK-fails outside tests or if the
// ID3D12Device5 query does not return S_OK.
void ContextImplDml::RemoveDeviceForTesting() {
CHECK_IS_TEST();
ComPtr<ID3D12Device5> d3d12_device_5;
CHECK_EQ(
adapter_->d3d12_device()->QueryInterface(IID_PPV_ARGS(&d3d12_device_5)),
S_OK);
d3d12_device_5->RemoveDevice();
}
}