#include "services/webnn/dml/command_recorder.h"
#include "base/compiler_specific.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/notreached.h"
#include "base/numerics/safe_conversions.h"
#include "base/trace_event/trace_event.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/error.h"
#include "services/webnn/dml/tensor_impl_dml.h"
#include "services/webnn/dml/utils.h"
namespace webnn::dml {
namespace {

// Builds a resource barrier description of type UAV for `resource`, with no
// barrier flags set. The only caller in this file passes nullptr as the
// resource.
D3D12_RESOURCE_BARRIER CreateUAVBarrier(ID3D12Resource* resource) {
  D3D12_RESOURCE_BARRIER barrier = {};
  barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV;
  barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
  barrier.UAV.pResource = resource;
  return barrier;
}

}  // namespace
// Factory for CommandRecorder.
//
// Creates, in this order, a compute-type D3D12 command allocator (on the D3D12
// device obtained from `dml_device`), a DirectML command recorder and a
// DirectML binding table, and passes them to the constructor together with
// `queue`.
//
// Returns the new recorder on success. If any creation call fails,
// RETURN_UNEXPECTED_IF_FAILED returns early from this function (presumably
// with the failing HRESULT as the unexpected value - the macro is defined
// elsewhere).
base::expected<std::unique_ptr<CommandRecorder>, HRESULT>
CommandRecorder::Create(scoped_refptr<CommandQueue> queue,
Microsoft::WRL::ComPtr<IDMLDevice1> dml_device) {
Microsoft::WRL::ComPtr<ID3D12CommandAllocator> command_allocator;
RETURN_UNEXPECTED_IF_FAILED(
GetD3D12Device(dml_device.Get())
->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE,
IID_PPV_ARGS(&command_allocator)));
Microsoft::WRL::ComPtr<IDMLCommandRecorder> command_recorder;
RETURN_UNEXPECTED_IF_FAILED(
dml_device->CreateCommandRecorder(IID_PPV_ARGS(&command_recorder)));
// No binding table description is supplied at creation time; the table is
// Reset() with a real description in InitializeOperator() and
// ExecuteOperator().
Microsoft::WRL::ComPtr<IDMLBindingTable> binding_table;
RETURN_UNEXPECTED_IF_FAILED(
dml_device->CreateBindingTable(nullptr, IID_PPV_ARGS(&binding_table)));
return base::WrapUnique(new CommandRecorder(
std::move(queue), std::move(dml_device), std::move(command_allocator),
std::move(command_recorder), std::move(binding_table)));
}
// Stores the command queue, the DirectML device and the D3D12/DirectML objects
// handed over by the caller. `d3d12_device_` is looked up from the member
// `dml_device_`, not from the already moved-from `dml_device` parameter.
// NOTE(review): this relies on `dml_device_` being initialized before
// `d3d12_device_`, i.e. on the member declaration order in the header, which
// is not visible here - confirm.
CommandRecorder::CommandRecorder(
scoped_refptr<CommandQueue> command_queue,
Microsoft::WRL::ComPtr<IDMLDevice1> dml_device,
Microsoft::WRL::ComPtr<ID3D12CommandAllocator> command_allocator,
Microsoft::WRL::ComPtr<IDMLCommandRecorder> command_recorder,
Microsoft::WRL::ComPtr<IDMLBindingTable> binding_table)
: command_queue_(std::move(command_queue)),
dml_device_(std::move(dml_device)),
d3d12_device_(GetD3D12Device(dml_device_.Get())),
command_allocator_(std::move(command_allocator)),
command_recorder_(std::move(command_recorder)),
binding_table_(std::move(binding_table)) {}
// Defaulted: no teardown logic beyond what the members' own destructors do.
CommandRecorder::~CommandRecorder() = default;
// Puts the recorder into the open (recording) state. Must not be called while
// the recorder is already open.
//
// Returns S_OK on success. On failure the failing HRESULT is returned and the
// recorder is left not open.
HRESULT CommandRecorder::Open() {
  CHECK(!is_open_);

  // Only reset the allocator once the queue reports that the last submission
  // made through this recorder has completed.
  if (last_submitted_fence_value_ <= command_queue_->GetCompletedValue()) {
    RETURN_IF_FAILED(command_allocator_->Reset());
  }

  if (command_list_) {
    // A command list from an earlier Open() exists: reuse it.
    RETURN_IF_FAILED(command_list_->Reset(command_allocator_.Get(), nullptr));
  } else {
    // First use: create the compute command list on the allocator.
    RETURN_IF_FAILED(d3d12_device_->CreateCommandList(
        0, D3D12_COMMAND_LIST_TYPE_COMPUTE, command_allocator_.Get(), nullptr,
        IID_PPV_ARGS(&command_list_)));
  }

  // Start the new recording with empty resource and tensor tracking.
  command_resources_.clear();
  command_tensor_impls_.clear();

  is_open_ = true;
  return S_OK;
}
// Convenience wrapper: closes the recorder and, if that succeeds, submits the
// recorded command list. Returns the first failing HRESULT, or S_OK when both
// steps succeed.
HRESULT CommandRecorder::CloseAndExecute() {
RETURN_IF_FAILED(Close());
RETURN_IF_FAILED(Execute());
return S_OK;
}
// Ends recording: closes the command list and marks the recorder as not open.
// The recorder must currently be open. If closing the command list fails, the
// failing HRESULT is returned and the recorder is still flagged as open.
HRESULT CommandRecorder::Close() {
TRACE_EVENT0("gpu", "dml::CommandRecorder::Close");
CHECK(is_open_);
RETURN_IF_FAILED(command_list_->Close());
is_open_ = false;
return S_OK;
}
// Submits the closed command list to the command queue. The recorder must not
// be open.
//
// Steps, in order:
//   1. If an earlier submission from this recorder has not completed yet, ask
//      the queue to wait on its fence value.
//   2. For every tracked tensor that is still alive, wait for its external
//      fence.
//   3. Execute the command list and remember the queue's new fence value.
//   4. Ask the queue to keep the allocator and all recorded resources
//      referenced until completion, and stamp each live tensor with the
//      submission fence value.
//
// Returns S_OK on success, otherwise the first failing HRESULT.
HRESULT CommandRecorder::Execute() {
  CHECK(!is_open_);

  // UINT64_MAX is excluded explicitly; it is treated as "nothing to wait for".
  // NOTE(review): presumably the initial value set in the header - confirm.
  if (last_submitted_fence_value_ != UINT64_MAX) {
    if (last_submitted_fence_value_ > command_queue_->GetCompletedValue()) {
      RETURN_IF_FAILED(command_queue_->WaitForFence(
          command_queue_->submission_fence(), last_submitted_fence_value_));
    }
  }

  for (auto& [command_buffer, webnn_tensor_impl] : command_tensor_impls_) {
    // Entries whose tensor is gone are skipped.
    if (!webnn_tensor_impl) {
      continue;
    }
    RETURN_IF_FAILED(webnn_tensor_impl->WaitForExternalFenceAndReset(
        command_queue_.get()));
  }

  RETURN_IF_FAILED(command_queue_->ExecuteCommandList(command_list_.Get()));
  last_submitted_fence_value_ = command_queue_->GetLastFenceValue();

  command_queue_->ReferenceUntilCompleted(command_allocator_);
  for (auto& recorded_resource : command_resources_) {
    command_queue_->ReferenceUntilCompleted(recorded_resource);
  }

  for (auto& [buffer_key, tracked_tensor] : command_tensor_impls_) {
    if (!tracked_tensor) {
      continue;
    }
    tracked_tensor->SetLastSubmissionFenceValue(last_submitted_fence_value_);
  }

  return S_OK;
}
// Records the given resource barriers into the open command list. The barrier
// count is converted to uint32_t with a checked cast, so a count that does not
// fit is treated as a fatal error rather than truncated.
void CommandRecorder::ResourceBarrier(
    base::span<const D3D12_RESOURCE_BARRIER> barriers) {
  CHECK(is_open_);

  const uint32_t barrier_count = base::checked_cast<uint32_t>(barriers.size());
  const D3D12_RESOURCE_BARRIER* const barrier_array = barriers.data();
  command_list_->ResourceBarrier(barrier_count, barrier_array);
}
// Records a copy of `byte_length` bytes from `src_buffer` (starting at
// `src_offset`) into `dst_buffer` (starting at `dst_offset`). The recorder
// must be open. Both buffers are appended to `command_resources_`, destination
// first, so that they stay referenced alongside the recorded command.
void CommandRecorder::CopyBufferRegion(
    Microsoft::WRL::ComPtr<ID3D12Resource> dst_buffer,
    uint64_t dst_offset,
    Microsoft::WRL::ComPtr<ID3D12Resource> src_buffer,
    uint64_t src_offset,
    uint64_t byte_length) {
  CHECK(is_open_);

  ID3D12Resource* const destination = dst_buffer.Get();
  ID3D12Resource* const source = src_buffer.Get();
  command_list_->CopyBufferRegion(destination, dst_offset, source, src_offset,
                                  byte_length);

  command_resources_.push_back(std::move(dst_buffer));
  command_resources_.push_back(std::move(src_buffer));
}
// Records a dispatch of `dispatchable` into the command list, using the
// recorder's binding table as currently configured.
//
// The recorder must be open: `command_list_` is only created by Open(), and
// is closed again by Close(), so recording outside that window would target a
// null or closed command list.
void CommandRecorder::RecordDispatch(IDMLDispatchable* dispatchable) {
  TRACE_EVENT0("gpu", "dml::CommandRecorder::RecordDispatch");
  // Same precondition that ResourceBarrier() and CopyBufferRegion() enforce
  // before touching `command_list_`.
  CHECK(is_open_);
  command_recorder_->RecordDispatch(command_list_.Get(), dispatchable,
                                    binding_table_.Get());
}
// Records an upload from `src_buffer` into the buffer of `dst_tensor` by
// delegating to dml::UploadBufferWithBarrier() (defined elsewhere, which
// receives this recorder and `buffer_size`), then registers `dst_tensor` as
// accessed by this recording.
//
// `dst_tensor` is dereferenced unconditionally and therefore must not be null.
void CommandRecorder::UploadTensorWithBarrier(
TensorImplDml* dst_tensor,
Microsoft::WRL::ComPtr<ID3D12Resource> src_buffer,
size_t buffer_size) {
dml::UploadBufferWithBarrier(this, dst_tensor->buffer(),
std::move(src_buffer), buffer_size);
OnTensorAccessed(dst_tensor);
}
// Records a readback from the buffer of `src_tensor` into `dst_buffer` by
// delegating to dml::ReadbackBufferWithBarrier() (defined elsewhere, which
// receives this recorder and `buffer_size`), then registers `src_tensor` as
// accessed by this recording.
//
// `src_tensor` is dereferenced unconditionally and therefore must not be null.
void CommandRecorder::ReadbackTensorWithBarrier(
Microsoft::WRL::ComPtr<ID3D12Resource> dst_buffer,
TensorImplDml* src_tensor,
size_t buffer_size) {
dml::ReadbackBufferWithBarrier(this, std::move(dst_buffer),
src_tensor->buffer(), buffer_size);
OnTensorAccessed(src_tensor);
}
// Records the initialization of `compiled_operator` into the open command
// list.
//
// `input_array_binding`, when present, must have type
// DML_BINDING_TYPE_BUFFER_ARRAY and is bound as the initializer's single
// input. `persistent_resource_binding`, when present, must have type
// DML_BINDING_TYPE_BUFFER with a non-null buffer and is bound as the
// initializer's single output.
//
// Objects used by the recorded work (temporary buffer, non-null input
// buffers, persistent buffer, initializer, descriptor heap) are appended to
// `command_resources_`.
//
// Returns S_OK on success, otherwise the first failing HRESULT.
HRESULT CommandRecorder::InitializeOperator(
IDMLCompiledOperator* compiled_operator,
const std::optional<DML_BINDING_DESC>& input_array_binding,
const std::optional<DML_BINDING_DESC>& persistent_resource_binding) {
TRACE_EVENT0("gpu", "dml::CommandRecorder::InitializeOperator");
CHECK(is_open_);
CHECK(compiled_operator);
// Create an initializer that covers just this one compiled operator.
Microsoft::WRL::ComPtr<IDMLOperatorInitializer> initializer;
IDMLCompiledOperator* compiled_operators[] = {compiled_operator};
RETURN_IF_FAILED(dml_device_->CreateOperatorInitializer(
1, compiled_operators, IID_PPV_ARGS(&initializer)));
DML_BINDING_PROPERTIES initialization_binding_properties =
initializer->GetBindingProperties();
// The heap always holds at least one descriptor, even when the initializer
// reports that it requires none; the binding table below is still sized with
// the reported count. NOTE(review): presumably so that the heap-start handles
// passed to the binding table are valid - confirm.
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> descriptor_heap;
const uint32_t num_descriptors_in_heap =
std::max(1u, initialization_binding_properties.RequiredDescriptorCount);
RETURN_IF_FAILED(CreateDescriptorHeap(
d3d12_device_.Get(), num_descriptors_in_heap,
L"WebNN_Descriptor_Heap_For_Initialization", descriptor_heap));
ID3D12DescriptorHeap* descriptor_heaps[] = {descriptor_heap.Get()};
command_list_->SetDescriptorHeaps( 1,
descriptor_heaps);
// Point the shared binding table at the initializer and the new heap.
DML_BINDING_TABLE_DESC binding_table_desc = {
.Dispatchable = initializer.Get(),
.CPUDescriptorHandle =
descriptor_heap->GetCPUDescriptorHandleForHeapStart(),
.GPUDescriptorHandle =
descriptor_heap->GetGPUDescriptorHandleForHeapStart(),
.SizeInDescriptors =
initialization_binding_properties.RequiredDescriptorCount};
RETURN_IF_FAILED(binding_table_->Reset(&binding_table_desc));
// If the initializer asks for temporary memory, allocate a default buffer of
// exactly that size, bind it, and keep it referenced with the recording.
auto temp_resource_size =
initialization_binding_properties.TemporaryResourceSize;
if (temp_resource_size > 0) {
Microsoft::WRL::ComPtr<ID3D12Resource> temp_resource;
RETURN_IF_FAILED(CreateDefaultBuffer(
d3d12_device_.Get(), temp_resource_size,
L"WebNN_Temporary_Buffer_For_Initialization", temp_resource));
DML_BUFFER_BINDING temp_buffer_binding{.Buffer = temp_resource.Get(),
.Offset = 0,
.SizeInBytes = temp_resource_size};
DML_BINDING_DESC temp_binding_desc{.Type = DML_BINDING_TYPE_BUFFER,
.Desc = &temp_buffer_binding};
binding_table_->BindTemporaryResource(&temp_binding_desc);
command_resources_.push_back(std::move(temp_resource));
}
if (input_array_binding.has_value()) {
CHECK_EQ(input_array_binding.value().Type, DML_BINDING_TYPE_BUFFER_ARRAY);
binding_table_->BindInputs( 1,
&input_array_binding.value());
// Keep every non-null buffer of the array referenced; null entries are
// skipped.
const DML_BUFFER_ARRAY_BINDING* dml_buffer_array_binding =
static_cast<const DML_BUFFER_ARRAY_BINDING*>(
input_array_binding.value().Desc);
for (size_t i = 0; i < dml_buffer_array_binding->BindingCount; ++i) {
ID3D12Resource* buffer =
UNSAFE_TODO(dml_buffer_array_binding->Bindings[i]).Buffer;
if (buffer) {
command_resources_.push_back(buffer);
}
}
}
// For initialization the persistent resource is bound as an output.
if (persistent_resource_binding.has_value()) {
CHECK_EQ(persistent_resource_binding.value().Type, DML_BINDING_TYPE_BUFFER);
binding_table_->BindOutputs( 1,
&persistent_resource_binding.value());
ID3D12Resource* persistent_resource =
static_cast<const DML_BUFFER_BINDING*>(
persistent_resource_binding.value().Desc)
->Buffer;
CHECK_NE(persistent_resource, nullptr);
command_resources_.push_back(persistent_resource);
}
// The Bind* calls above are not checked individually; instead the device is
// queried for removal once, before the dispatch is recorded.
RETURN_IF_FAILED(dml_device_->GetDeviceRemovedReason());
RecordDispatch(initializer.Get());
// Keep the initializer and its descriptor heap referenced with the recording.
command_resources_.push_back(std::move(initializer));
command_resources_.push_back(std::move(descriptor_heap));
// When a persistent resource was bound, follow the dispatch with a UAV
// barrier that names no specific resource.
if (persistent_resource_binding.has_value()) {
auto uav = CreateUAVBarrier(nullptr);
command_list_->ResourceBarrier( 1, &uav);
}
return S_OK;
}
// Records the execution of `compiled_operator` into the open command list,
// using `descriptor_heap` for the binding table's descriptors.
//
// `temporary_resource_binding` is required (type DML_BINDING_TYPE_BUFFER,
// non-null buffer) when the operator reports a non-zero temporary resource
// size; `persistent_resource_binding` is required under the same conditions
// for a non-zero persistent resource size. When the corresponding size is
// zero the binding is not looked at.
//
// `descriptor_heap` is dereferenced unconditionally and must not be null.
// The bound buffers, the operator and the heap are appended to
// `command_resources_`.
//
// Returns S_OK on success, otherwise the first failing HRESULT.
HRESULT CommandRecorder::ExecuteOperator(
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiled_operator,
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> descriptor_heap,
const std::optional<DML_BINDING_DESC>& persistent_resource_binding,
const std::optional<DML_BINDING_DESC>& temporary_resource_binding) {
TRACE_EVENT0("gpu", "dml::CommandRecorder::ExecuteOperator");
CHECK(is_open_);
CHECK(compiled_operator);
DML_BINDING_PROPERTIES execution_binding_properties =
compiled_operator->GetBindingProperties();
ID3D12DescriptorHeap* descriptor_heaps[] = {descriptor_heap.Get()};
command_list_->SetDescriptorHeaps( 1,
descriptor_heaps);
// Point the shared binding table at the operator and the caller's heap.
DML_BINDING_TABLE_DESC binding_table_desc = {
.Dispatchable = compiled_operator.Get(),
.CPUDescriptorHandle =
descriptor_heap->GetCPUDescriptorHandleForHeapStart(),
.GPUDescriptorHandle =
descriptor_heap->GetGPUDescriptorHandleForHeapStart(),
.SizeInDescriptors =
execution_binding_properties.RequiredDescriptorCount};
RETURN_IF_FAILED(binding_table_->Reset(&binding_table_desc));
// Temporary resource: supplied by the caller here (not allocated by this
// function), bound only if the operator needs one.
auto temp_resource_size = execution_binding_properties.TemporaryResourceSize;
if (temp_resource_size > 0) {
CHECK_EQ(temporary_resource_binding.has_value(), true);
CHECK_EQ(temporary_resource_binding.value().Type, DML_BINDING_TYPE_BUFFER);
binding_table_->BindTemporaryResource(&temporary_resource_binding.value());
// The bind call is not checked directly; query the device for removal.
RETURN_IF_FAILED(dml_device_->GetDeviceRemovedReason());
ID3D12Resource* temporary_resource =
static_cast<const DML_BUFFER_BINDING*>(
temporary_resource_binding.value().Desc)
->Buffer;
CHECK_NE(temporary_resource, nullptr);
command_resources_.push_back(temporary_resource);
}
// Persistent resource: for execution it is bound through
// BindPersistentResource(), only if the operator needs one.
auto persistent_buffer_size =
execution_binding_properties.PersistentResourceSize;
if (persistent_buffer_size > 0) {
CHECK_EQ(persistent_resource_binding.has_value(), true);
CHECK_EQ(persistent_resource_binding.value().Type, DML_BINDING_TYPE_BUFFER);
binding_table_->BindPersistentResource(
&persistent_resource_binding.value());
// The bind call is not checked directly; query the device for removal.
RETURN_IF_FAILED(dml_device_->GetDeviceRemovedReason());
ID3D12Resource* persistent_resource =
static_cast<const DML_BUFFER_BINDING*>(
persistent_resource_binding.value().Desc)
->Buffer;
CHECK_NE(persistent_resource, nullptr);
command_resources_.push_back(persistent_resource);
}
RecordDispatch(compiled_operator.Get());
// Keep the operator and the descriptor heap referenced with the recording.
command_resources_.push_back(std::move(compiled_operator));
command_resources_.push_back(std::move(descriptor_heap));
return S_OK;
}
// Binds `input_bindings` as the inputs of the binding table. An empty span is
// a no-op that returns S_OK without touching the binding table or the device.
// After binding, the device is queried for removal; a failure there is
// returned to the caller.
HRESULT CommandRecorder::BindInputs(
    base::span<const DML_BINDING_DESC> input_bindings) {
  TRACE_EVENT0("gpu", "dml::CommandRecorder::BindInputs");

  if (input_bindings.empty()) {
    return S_OK;
  }

  binding_table_->BindInputs(
      base::checked_cast<uint32_t>(input_bindings.size()),
      input_bindings.data());
  RETURN_IF_FAILED(dml_device_->GetDeviceRemovedReason());
  return S_OK;
}
// Binds `output_bindings` as the outputs of the binding table. Unlike
// BindInputs(), the bind call is made even when the span is empty. After
// binding, the device is queried for removal; a failure there is returned to
// the caller, otherwise S_OK.
HRESULT CommandRecorder::BindOutputs(
    base::span<const DML_BINDING_DESC> output_bindings) {
  TRACE_EVENT0("gpu", "dml::CommandRecorder::BindOutputs");

  const uint32_t output_count =
      base::checked_cast<uint32_t>(output_bindings.size());
  const DML_BINDING_DESC* const output_array = output_bindings.data();
  binding_table_->BindOutputs(output_count, output_array);

  RETURN_IF_FAILED(dml_device_->GetDeviceRemovedReason());
  return S_OK;
}
// Registers `tensor` as used by the current recording: stores a weak pointer
// to it in `command_tensor_impls_`, keyed by the tensor's buffer. Execute()
// later walks these entries. `tensor` must not be null.
// NOTE(review): emplace() is used, so if an entry for the same buffer already
// exists it is presumably left untouched (depends on the container type
// declared in the header) - confirm.
void CommandRecorder::OnTensorAccessed(TensorImplDml* tensor) {
command_tensor_impls_.emplace(tensor->buffer(), tensor->AsWeakPtr());
}
// Appends `object` to `command_resources_`. Execute() passes every entry of
// that list to the command queue's ReferenceUntilCompleted(), and Open()
// clears the list.
void CommandRecorder::ReferenceCommandResources(
Microsoft::WRL::ComPtr<IUnknown> object) {
command_resources_.push_back(std::move(object));
}
}