#ifndef SERVICES_WEBNN_DML_UTILS_H_
#define SERVICES_WEBNN_DML_UTILS_H_
#include <cstdint>
#include <string>
#include <string_view>
#include <vector>

#include "base/component_export.h"
#include "base/containers/span.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "third_party/microsoft_dxheaders/include/directml.h"
#include "third_party/microsoft_dxheaders/src/include/directx/d3d12.h"

// NOTE(review): <wrl.h> is intentionally kept after the DirectX headers above;
// including it earlier may let Windows SDK D3D headers be picked up before the
// third_party ones. Confirm before reordering.
#include <wrl.h>
namespace webnn::dml {

// Shared helpers for the WebNN DirectML backend.
//
// NOTE(review): this header only declares these helpers. The descriptions
// below are derived from their signatures and names, and should be checked
// against the corresponding .cc implementation.

// Declared here so this header does not rely on another include happening to
// declare it. It is only used through pointers below, so a forward
// declaration is sufficient.
class TensorImplDml;

// Returns an element count derived from `dimensions` and, optionally,
// `strides` (an empty span by default). Presumably this is the number of
// elements a strided layout spans in memory. How an empty `strides` is
// treated is defined in the implementation -- TODO confirm.
uint64_t CalculatePhysicalElementCount(base::span<const uint32_t> dimensions,
                                       base::span<const uint32_t> strides = {});

// Returns a size for a DirectML buffer tensor of `data_type` with the given
// `dimensions` and `strides`. NOTE(review): presumably a byte size following
// DirectML's DMLCalcBufferTensorSize rules (including any alignment padding).
// Confirm.
uint64_t CalculateDMLBufferTensorSize(DML_TENSOR_DATA_TYPE data_type,
                                      const std::vector<uint32_t>& dimensions,
                                      const std::vector<uint32_t>& strides);

// Returns the D3D12 device associated with `dml_device`, as a reference-counted
// ComPtr. Presumably obtained through the DML device's parent device -- confirm.
Microsoft::WRL::ComPtr<ID3D12Device> GetD3D12Device(IDMLDevice1* dml_device);

// Returns the highest DML_FEATURE_LEVEL that `dml_device` reports as
// supported. Presumably queried via IDMLDevice::CheckFeatureSupport -- confirm.
DML_FEATURE_LEVEL GetMaxSupportedDMLFeatureLevel(IDMLDevice1* dml_device);

// Returns a human-readable name for `dml_feature_level`.
// NOTE(review): the returned view presumably refers to static storage; confirm
// before storing it beyond the caller's scope.
std::string_view DMLFeatureLevelToString(DML_FEATURE_LEVEL dml_feature_level);

// Returns a D3D12_RESOURCE_BARRIER describing a state transition of `resource`
// from `before` to `after`. The barrier is returned by value. It is presumably
// not recorded anywhere by this call; the caller submits it -- confirm.
D3D12_RESOURCE_BARRIER COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateTransitionBarrier(ID3D12Resource* resource,
                            D3D12_RESOURCE_STATES before,
                            D3D12_RESOURCE_STATES after);

// Records, via `command_recorder`, a copy of `buffer_size` bytes from
// `src_buffer` into `dst_buffer`, presumably wrapped in the state-transition
// barriers the copy needs. The ComPtrs are taken by value, presumably so the
// recorder can keep both resources alive until the GPU work completes --
// confirm.
void COMPONENT_EXPORT(WEBNN_SERVICE)
    UploadBufferWithBarrier(CommandRecorder* command_recorder,
                            Microsoft::WRL::ComPtr<ID3D12Resource> dst_buffer,
                            Microsoft::WRL::ComPtr<ID3D12Resource> src_buffer,
                            size_t buffer_size);

// Records, via `command_recorder`, a copy of `buffer_size` bytes out of
// `default_buffer` into `readback_buffer`, presumably with the required
// barriers. Note the destination-first parameter order, matching
// UploadBufferWithBarrier -- confirm the copy direction in the implementation.
void COMPONENT_EXPORT(WEBNN_SERVICE) ReadbackBufferWithBarrier(
    CommandRecorder* command_recorder,
    Microsoft::WRL::ComPtr<ID3D12Resource> readback_buffer,
    Microsoft::WRL::ComPtr<ID3D12Resource> default_buffer,
    size_t buffer_size);

// Tensor variant of UploadBufferWithBarrier: records a copy of `buffer_size`
// bytes from `src_buffer` into the resource backing `dst_tensor`.
void COMPONENT_EXPORT(WEBNN_SERVICE)
    UploadTensorWithBarrier(CommandRecorder* command_recorder,
                            TensorImplDml* dst_tensor,
                            Microsoft::WRL::ComPtr<ID3D12Resource> src_buffer,
                            size_t buffer_size);

// Tensor variant of ReadbackBufferWithBarrier: records a copy of `buffer_size`
// bytes from the resource backing `src_tensor` into `dst_buffer`.
void COMPONENT_EXPORT(WEBNN_SERVICE)
    ReadbackTensorWithBarrier(CommandRecorder* command_recorder,
                              Microsoft::WRL::ComPtr<ID3D12Resource> dst_buffer,
                              TensorImplDml* src_tensor,
                              size_t buffer_size);

// Creates a mojom::Error with `error_code` and `error_message`. `label` is
// empty by default. Presumably it identifies the WebNN object the error refers
// to and is incorporated into the message -- confirm.
mojom::ErrorPtr CreateError(mojom::Error::Code error_code,
                            const std::string& error_message,
                            std::string_view label = "");

// The buffer factories below share a contract. Each creates a buffer resource
// of `size` bytes, tags it with `name_for_debugging`, and writes it to the
// `resource` out-parameter. Each returns the HRESULT of the creation;
// `resource` should only be relied on when the result indicates success.
// NOTE(review): heap types below are inferred from the function names.

// Buffer on a default (GPU-local) heap.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateDefaultBuffer(ID3D12Device* device,
                        uint64_t size,
                        const wchar_t* name_for_debugging,
                        Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Buffer on an upload heap (CPU writes, GPU reads).
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateUploadBuffer(ID3D12Device* device,
                       uint64_t size,
                       const wchar_t* name_for_debugging,
                       Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Buffer on a readback heap (GPU writes, CPU reads).
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateReadbackBuffer(ID3D12Device* device,
                         uint64_t size,
                         const wchar_t* name_for_debugging,
                         Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Upload buffer on a custom heap. Presumably D3D12_HEAP_TYPE_CUSTOM with CPU
// page properties suited to uploads, e.g. on UMA adapters -- confirm.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateCustomUploadBuffer(ID3D12Device* device,
                             uint64_t size,
                             const wchar_t* name_for_debugging,
                             Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Readback buffer on a custom heap. Presumably D3D12_HEAP_TYPE_CUSTOM with CPU
// page properties suited to readback -- confirm.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE) CreateCustomReadbackBuffer(
    ID3D12Device* device,
    uint64_t size,
    const wchar_t* name_for_debugging,
    Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Creates a descriptor heap with room for `num_descriptors` descriptors, tags
// it with `name_for_debugging`, and writes it to `descriptor_heap`. Returns
// the HRESULT of the creation. NOTE(review): the heap type and shader
// visibility are chosen in the implementation; presumably a shader-visible
// CBV/SRV/UAV heap for DirectML bindings -- confirm.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE) CreateDescriptorHeap(
    ID3D12Device* device,
    uint32_t num_descriptors,
    const wchar_t* name_for_debugging,
    Microsoft::WRL::ComPtr<ID3D12DescriptorHeap>& descriptor_heap);

}  // namespace webnn::dml
#endif